ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [백준] 23807번 : 두 단계 최단 경로3 - 자바(JAVA)
    알고리즘/백준 2022. 6. 11. 00:01

    https://www.acmicpc.net/problem/23807

     

    23807번: 두 단계 최단 경로 3

    첫째 줄에 정점의 수 N(10 ≤ N ≤ 100,000), 간선의 수 M(10 ≤ M ≤ 300,000)이 주어진다. 다음 M개 줄에 간선 정보 u v w가 주어지며 도시 u와 도시 v 사이의 가중치가 정수 w인 양방향 도로를 나타낸

    www.acmicpc.net

     

    문제 해석

    시작점과 목적지의 최단거리를 구하는 문제이기 때문에 다익스트라가 바로 떠올랐습니다.

    하지만 중간에 꼭 거쳐가야 하는 노드들이 주어집니다.

    임의의 P개의 정점중에 적어도 3개의 정점은 무조건 거쳐가야 함.

     

    문제 풀이 전 설계

    DP를 사용해야 하나..? 전혀 방법이 떠오르지 않았습니다.

    따라서 완전 탐색으로 접근해 보려고 합니다.

     

    1번 start를 시작점으로 하는 다익스트라를 구한다.

    2번 각각의 P개의 정점들을 시작점으로 하는 다익스트라를 구한다.

    최대 101번의 다익스트라 연산

     

    100P3을 통하여 완전 탐색!

    start -> P1, P1-> P2, P2 ->P3, P3->end의 최단거리를 구해줍니다.

     

    코드 : 81% 오답

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.PriorityQueue;
    import java.util.StringTokenizer;
    
    
    public class Main {
    
        static class Node{
            int index;
            int weight;
    
            public Node(int index, int weight) {
                this.index = index;
                this.weight = weight;
            }
        }
    
        static int[][] DP;
        static final int INF = Integer.MAX_VALUE;
        static List<ArrayList<Node>> graph = new ArrayList<>();
        public static void main(String[] args) throws IOException {
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            StringTokenizer st = new StringTokenizer(br.readLine()," ");
            int nodeCount = Integer.parseInt(st.nextToken());
            int wayCount = Integer.parseInt(st.nextToken());
    
            //양방향 그래프 생성
    
            for(int i=0; i<=nodeCount; i++){
                graph.add(new ArrayList<>());
            }
    
            for(int i=0; i<wayCount; i++){
                st = new StringTokenizer(br.readLine()," ");
                int from = Integer.parseInt(st.nextToken());
                int to = Integer.parseInt(st.nextToken());
                int weight = Integer.parseInt(st.nextToken());
                graph.get(from).add(new Node(to, weight));
                graph.get(to).add(new Node(from, weight));
    
            }
            //양방향 그래프 생성 끝
    
            st = new StringTokenizer(br.readLine()," ");
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());
    
            int mustPassedCount = Integer.parseInt(br.readLine());
            List<Integer> mustPassedNodes = new ArrayList<>(mustPassedCount);
            st = new StringTokenizer(br.readLine()," ");
            for(int i=0; i< mustPassedCount; i++){
                mustPassedNodes.add(Integer.parseInt(st.nextToken()));
            }
    
            DP = new int[mustPassedCount+1][nodeCount+1];
            //거리 INF로 초기화
            for(int i=0; i<mustPassedCount+1; i++){
                for(int j=0; j <nodeCount+1; j++){
                    DP[i][j] = INF;
                }
            }
    
            //다익스트라 최대 101번 수행
            dijstra(0,start);
            for(int i=0; i<mustPassedCount; i++){
                dijstra(i+1,mustPassedNodes.get(i));
            }
    
            //답찾기
            int answer = INF;
            for(int i=0; i<mustPassedCount; i++){
                for(int j=0; j<mustPassedCount; j++){
                    for(int k=0; k<mustPassedCount; k++){
                        if(i==j || j==k || k==i){
                            continue;
                        }
                        if(DP[0][mustPassedNodes.get(i)] == INF){
                            continue;
                        }
                        if(DP[i+1][mustPassedNodes.get(j)] == INF){
                            continue;
                        }
                        if(DP[j+1][mustPassedNodes.get(k)] == INF){
                            continue;
                        }
                        if(DP[k+1][end] == INF){
                            continue;
                        }
                        answer = Math.min(answer, DP[0][mustPassedNodes.get(i)] + DP[i+1][mustPassedNodes.get(j)] + DP[j+1][mustPassedNodes.get(k)]+ DP[k+1][end]);
                    }
                }
            }
            if(answer == INF){
                answer = -1;
            }
            System.out.println(answer);
        }
    
    
    
        private static void dijstra(int index, int start){
            PriorityQueue<Node> pq = new PriorityQueue<>((c1,c2) -> c1.weight - c2.weight);
            DP[index][start] = 0;
            pq.add(new Node(start, 0));
            while(!pq.isEmpty()){
                Node cur = pq.poll();
                int curWeight = cur.weight;
                int curIndex = cur.index;
    
                if(DP[index][curIndex] < curWeight){
                    continue;
                }
                //현재 정점으로 갈 수 있는 위치들 찾기
                for(int i=0; i < graph.get(curIndex).size(); i++){
                    int nextIndex = graph.get(curIndex).get(i).index;
                    int nextWeight = graph.get(curIndex).get(i).weight;
                    //작으면 갱신하고 add
                    if(DP[index][nextIndex] > curWeight + nextWeight){
                        DP[index][nextIndex] = curWeight + nextWeight;
                        pq.add(new Node(nextIndex,curWeight + nextWeight));
                    }
                }
    
            }
        }
    }

    로직에는 전혀 문제가 없어 보이고 overflow 문제같아서 int를 long으로 교체해 보았습니다.

    정답코드

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.util.*;
    
    
    public class Main_23807_두단계최단경로3 {
    
        static class Node{
            int index;
            long weight;
    
            public Node(int index, long weight) {
                this.index = index;
                this.weight = weight;
            }
        }
    
        static long[][] DP;
        static final long INF = 300_000L * 1_000_000L +1;
        static List<ArrayList<Node>> graph = new ArrayList<>();
        public static void main(String[] args) throws IOException {
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            StringTokenizer st = new StringTokenizer(br.readLine()," ");
            int nodeCount = Integer.parseInt(st.nextToken());
            int wayCount = Integer.parseInt(st.nextToken());
    
            //양방향 그래프 생성
    
            for(int i=0; i<=nodeCount; i++){
                graph.add(new ArrayList<>());
            }
    
            for(int i=0; i<wayCount; i++){
                st = new StringTokenizer(br.readLine()," ");
                int from = Integer.parseInt(st.nextToken());
                int to = Integer.parseInt(st.nextToken());
                int weight = Integer.parseInt(st.nextToken());
                graph.get(from).add(new Node(to, weight));
                graph.get(to).add(new Node(from, weight));
    
            }
            //양방향 그래프 생성 끝
    
            st = new StringTokenizer(br.readLine()," ");
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());
    
            int mustPassedCount = Integer.parseInt(br.readLine());
            List<Integer> mustPassedNodes = new ArrayList<>(mustPassedCount);
            st = new StringTokenizer(br.readLine()," ");
            for(int i=0; i< mustPassedCount; i++){
                mustPassedNodes.add(Integer.parseInt(st.nextToken()));
            }
    
            DP = new long[mustPassedCount+1][nodeCount+1];
            //거리 INF로 초기화
            for(int i=0; i<mustPassedCount+1; i++){
                for(int j=0; j <nodeCount+1; j++){
                    DP[i][j] = INF;
                }
            }
    
            //다익스트라 최대 101번 수행
            dijstra(0,start);
            for(int i=0; i<mustPassedCount; i++){
                dijstra(i+1,mustPassedNodes.get(i));
            }
    
            //답찾기
            long answer = INF;
            for(int i=0; i<mustPassedCount; i++){
                for(int j=0; j<mustPassedCount; j++){
                    for(int k=0; k<mustPassedCount; k++){
                        if(i==j || j==k || k==i){
                            continue;
                        }
                        if(DP[0][mustPassedNodes.get(i)] == INF){
                            continue;
                        }
                        if(DP[i+1][mustPassedNodes.get(j)] == INF){
                            continue;
                        }
                        if(DP[j+1][mustPassedNodes.get(k)] == INF){
                            continue;
                        }
                        if(DP[k+1][end] == INF){
                            continue;
                        }
                        answer = Math.min(answer, DP[0][mustPassedNodes.get(i)] + DP[i+1][mustPassedNodes.get(j)] + DP[j+1][mustPassedNodes.get(k)]+ DP[k+1][end]);
                    }
                }
            }
            if(answer == INF){
                answer = -1;
            }
            System.out.println(answer);
        }
    
    
    
        private static void dijstra(int index, int start){
            PriorityQueue<Node> pq = new PriorityQueue<>(Comparator.comparingLong(c -> c.weight));
            DP[index][start] = 0;
            pq.add(new Node(start, 0));
            while(!pq.isEmpty()){
                Node cur = pq.poll();
                long curWeight = cur.weight;
                int curIndex = cur.index;
    
                if(DP[index][curIndex] < curWeight){
                    continue;
                }
                //현재 정점으로 갈 수 있는 위치들 찾기
                for(int i=0; i < graph.get(curIndex).size(); i++){
                    int nextIndex = graph.get(curIndex).get(i).index;
                    long nextWeight = graph.get(curIndex).get(i).weight;
                    //작으면 갱신하고 add
                    if(DP[index][nextIndex] > curWeight + nextWeight){
                        DP[index][nextIndex] = curWeight + nextWeight;
                        pq.add(new Node(nextIndex,curWeight + nextWeight));
                    }
                }
    
            }
        }
    }

    댓글

Designed by Tistory.