본문 바로가기

Algorithm/Problems

[JAVA]슈퍼컴퓨터 클러스터(Softeer)

문제

문제

 대규모 머신 러닝에서는 여러 컴퓨터를 하나의 클러스터로 묶어 계산을 수행하는 경우가 많다. 병렬 컴퓨팅 파워가 늘어나면 훨씬 더 거대한 데이터도 실용적으로 사용할 수 있게 된다. 클라우드 컴퓨팅을 이용하는 기업도 많지만, 개인정보와 보안, 네트워킹, 비용 등의 문제로 직접 클러스터를 구축하는 경우도 많다.

 현지도 이러한 머신 러닝용 클러스터를 관리하는 역할을 맡고 있다. 클러스터의 성능은 컴퓨터의 수가 많아질 수록, 각각의 성능이 올라갈 수록 향상된다. 그런데 어느 날 협업 중인 몇몇 연구실에서 클러스터의 성능을 업그레이드해 달라는 요청을 보내 왔다. 이들은 특히 클러스터를 이루는 각각의 컴퓨터 중 성능이 가장 낮은 컴퓨터 의 성능이 병목이 된다고 알려 왔다.

 이 클러스터에는 N대의 컴퓨터가 있으며, 각각의 성능은 ai라는 정수로 평가할 수 있다. 현지는 각각의 컴퓨터에 비용을 들여 업그레이드를 진행할 수 있다. 성능을 d만큼 향상시키는 데에 드는 비용은 $d^2$원이다. (단, d는 자연수이다.)

업그레이드를 하지 않는 컴퓨터가 있어도 되지만, 한 컴퓨터에 두 번 이상 업그레이드를 수행할 수는 없다.

 업그레이드를 위한 예산이 B원 책정되어 있다. 현지의 목표는 B원 이하의 총 비용으로 업그레이드를 수행하여, 성능이 가장 낮은 컴퓨터의 성능을 최대화 하는 것이 목표이다. 이 최선의 최저성능을 계산하는 프로그램을 작성하시오.

제약조건

  • 1≤N≤$10^5$인 정수
  • 1≤$a_i$≤$10^9$인 정수
  • 1≤B≤$10^18$인 정수

입력형식

첫째 줄에 컴퓨터의 수와 예산을 나타내는 정수 N과 B가 공백을 사이에 두고 주어진다.
둘째 줄에 각 컴퓨터의 성능을 나타내는 N개의 정수 $a_1, a_2, ..., a_n$이 공백을 사이에 두고 주어진다.

B의 범위가 매우 넓어, 사용하는 프로그래밍 언어에 따라 64비트 정수형을 사용해야 할 수 있음에 유의하시오.

출력형식

첫째 줄에 예산을 효율적으로 사용했을 때, 성능이 가장 낮은 컴퓨터의 성능으로 가능한 최댓값을 출력하시오.

 

입력예제1

4 10
5 5 6 1

출력예제1

4

 네 번째 컴퓨터의 성능을 1에서 4로 업그레이드 하는 데 드는 비용은 $3^2=9$원이다. 이 경우, 가장 낮은 성능의 컴퓨터는 4의 성능을 가지게 되며, 가장 낮은 성능으로 가능한 최대값은 4가 됨을 알 수 있다.

 

입력예제2

10 10
5 3 9 8 4 3 1 8 6 3
111

출력예제2

3

 일곱 번째 컴퓨터의 성능을 1에서 3으로 업그레이드 하는 데 드는 비용은 $2^2=4$원이다. 이 경우, 가장 낮은 성능의 컴퓨터는 3의 성능을 가지게 된다.

 가장 낮은 성능이 4가 되기 위해서는 원래 3의 성능을 가지고 있었던 두 번째, 여섯 번째, 그리고 열 번째 컴퓨터를 향상시키는 데 드는 비용 $1^2+1^2+1^2=3$원과 일곱 번째 컴퓨터의 성능을 3이 아닌 4로 향상시키는 데 드는 비용 $(4-1)^2=9$원, 총 12원이 들기 때문에 불가능하다.

 

입력예제3

10 10000000  
10 2 2 9 6 1 8 3 1 9

출력예제3

1005

풀이

 우선 문제를 생각했을때 예산 안에서 가능한 최대의 성능을 찾는 문제이기에 탐색 문제이고, 이렇게 범위가 넓은 경우에는 $O(log)$를 가지는 이분탐색을 먼저 떠올렸다. 그리고, N의 데이터 사이즈가 $10^5$이고, 입력값의 순서나 위치가 중요하지 않기 때문에, 해시맵을 이용해서 <성능, 갯수>로 매핑을 했다.

 이분탐색을 통해 정답을 찾아나가는데, 내부 반복문을 통해 비용 계산이 이루어진다. 따라서, 내부 반복문의 횟수를 최소화하기 위해 정렬을 해준다.(처음에는 내부 반복문 안에서 리스트를 만들어 정렬을 하는식으로 했었는데, 반복문 밖으로 빼서 한번만 했으면 되는거다)

 정답 제출을 하니 웬만한 케이스는 정답처리가 되었는데, 일부 케이스에서 오답과 런타임 오류가 발생했다. 문제에도 주어져 있듯, B의 사이즈가 int의 범위보다 크기에 long을 사용해야 했는데, 처음에는 total과 B모두 int로 사용했었다. 코드 수정을 통해 이 부분을 변경해서 런타임 오류인 케이스는 사라졌지만 오답인 케이스는 잔존했다.

원인 분석을 하다가 다른 사람의 코드를 참고했다.

 원인은 이분 탐색의 범위에 있었다. 처음에는 hi를 $a_i$의 범위로 생각해서 지정했었으나, 문제의 목적이 성능이 가장 낮은 컴퓨터의 성능을 최대화 하는것이기에, 예산만 남아있다면 얼마든지 $a_i$의 범위를 넘어 컴퓨터를 업그레이드 할 수 있는 것이었다. 따라서, hi를 $10^9$이 아닌 현 컴퓨터의 최대 성능 + B의 제곱근 값으로 설정해주니 정답처리가 되었다. B의 제곱근으로 하는 이유는 아마도 $a_i$의 최대범위인 $10^9$이 더해지는 게 맞는데 사실 B의 값에 따라 상계가 정해지는 것이기에 그렇다고 생각했다.

고찰

 런타임 오류가 나는 점에서 자료형의 문제임을 깨닫고 수정한 것에서 성장했다고는 할 수 있으나, 아직 실수가 존재한다. 조금 더 문제를 많이 풀자!

import java.util.*;
import java.io.*;

public class Main
{
    public static void main(String args[]) throws Exception
    {
        HashMap<Integer, Integer> hash = new HashMap(); // <값, 개수> 해시맵

        // 입력값 할당 O(N)
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int N = Integer.parseInt(st.nextToken());
        long B = Long.parseLong(st.nextToken());
        st = new StringTokenizer(br.readLine());
        while(st.hasMoreTokens()){
            int cost = Integer.parseInt(st.nextToken());
            hash.put(cost, hash.getOrDefault(cost, 0) + 1);
        }

        // 값의 범위가 10^9까지, 이분 탐색 사용
        long answer = 0;
        ArrayList<Integer> keys = new ArrayList(hash.keySet()); // 성능값 리스트
        Collections.sort(keys);
        long lo = keys.get(0);
        long hi = keys.get(keys.size()-1) + (int)Math.sqrt(B);
        // 이분 탐색 O(60) (log_2{10^18})
        while(lo <= hi){
            long mid = (lo + hi) / 2; // 목표 성능
            long total = 0; // 예상 비용
            // 성능값으로 비용 계산 O(N)
            for(int cost : keys){
                // 목표 성능 이상인 경우 break;
                if(cost >= mid)
                    break;
                // 목표 성능 미만인 경우 비용 계산
                total += Math.pow(mid - cost, 2) * hash.get(cost);
                // 예산 초과일 경우 반복문 종료
                if(total > B)
                    break;
            }
            // 비용이 예산이하인 경우 하한 상향
            if(total <= B){
                answer = mid;
                lo = mid + 1;
            }
            // 예산 초과일 경우 상한 하향
            else{
                hi = mid - 1;
            }
        }
        System.out.println(answer);
    }
}