알고리즘 연습/동적 계획법 상급

[🥇3 / 백준 11049 / 파이썬] 행렬 곱셈 순서

김세진 2021. 7. 27. 16:12
반응형

 

 

11049번: 행렬 곱셈 순서

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

www.acmicpc.net

문제

크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.

예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

  • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
  • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.

행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.

입력

첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.

둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)

항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.

출력

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.

 

예제 입력 

3
5 3
3 2
2 6

예제 출력 

90


 

풀이

 

행렬 곱셈 연산의 최솟값을 구하는 문제이다.

우선 다이나믹 프로그래밍을 어떻게 적용할 수 있을 지 살펴보자.

 

 

1. 부분 문제로 해결할 수 있는가?

 

행렬의 곱셈 순서는 문제에 나왔다시피 ABC 를 곱할 때, (AB)C 혹은 A(BC) 두 가지로 나눠서 생각해볼 수 있다.

즉, 더 작은 문제로 나누어 해결할 수 있다!

 

 

2. 메모이제이션을 적용할 수 있는가?

 

만약 ABCD 를 곱하는 방법을 구한다면 AB, BC, CD 는 여러 번 곱하게 될 것이다.

행렬의 개수가 많다면 ABC 혹은 BCD와 같은 3개 단위도 계속하여 여러 번 구하게 될 것이다.

 

따라서, 행렬의 부분 구간의 최솟값을 저장해 둔다면 굳이 똑같은 값을 여러번 구하지 않아도 된다.

메모이제이션이 가능하다는 말이다.

 

이는 bottom-up 형식으로 문제를 풀 때 매우 유리하다.

 

 

위의 정리된 내용을 바탕으로 다이나믹 프로그래밍을 설계한다.

필자는 부분 문제를 수월하게 풀기 위해 재귀를 활용했다.

 

2차원 배열을 만들고, i~j 구간의 최솟값을 이 배열에 dp[i][j] 형식으로 저장해둘 것이다.

 

해당 구간의 값을 구할 때, 이미 구간의 값이 배열에 있다면 이 값을 리턴하고

아니라면 계속하여 양분할 해 나가면서 부분 문제로 나누어 값을 구한다.

 

 

단, 해당 문제는 PyPy3로 제출해야 한다.

python3로 푼 사람이 현 시점까지 단 3명으로, 극한의 시간복잡도 최적화를 요구하는 것 같다.

채점 기준을 조금 수정해야 하지 않나 싶다.

 

import sys
input = sys.stdin.readline

n = int(input())
a = [list(map(int,input().split())) for i in range(n)]
d = [[0 for i in range(n)] for i in range(n)]

# 인접한 2개의 행렬은 미리 곱하여 배열에 넣는다.
for i in range(n-1):
    d[i][i+1] = a[i][0]*a[i+1][0]*a[i+1][1]

def dp(start,end):
    if d[start][end] != 0:
        return d[start][end]
    if start == end:
        return 0
    
    r = float('inf')
    for i in range(start,end):
        temp = dp(start,i) + dp(i+1,end) + a[start][0]*a[i+1][0]*a[end][1]
        if r > temp:
            r = temp
    d[start][end] = r
    return r

print(dp(0,n-1))
반응형