이번 장은 재귀를 통한 분할정복(Divide and conquer)을 다룬다.
재귀와 분할정복은 다른 포스팅에서 간략하게 다룬 적이 있다.
재귀 : https://codecpr.tistory.com/4
분할정복 [2.3 알고리즘의 설계 참고] : https://codecpr.tistory.com/24?category=516456
Maximum subarray problem
가장 먼저 다룰 알고리즘은 최대 부분 배열 문제(Maximum-subarray problem)이다.
최대 부분 배열 문제는 배열을 입력으로 받아 가장 큰 합을 가지는 연속 부분 배열을 찾는다.
이를 활용한 문제는 다음과 같다.
아래와 같은 향후 17일간의 주가 그래프가 주어졌다고 하자.
우리가 풀 문제는 "언제 사서, 언제 팔아야 이득일까?" 이다.
즉, 특정한 위치에서 특정한 위치까지 정해졌을 때, 그 안의 모든 값들의 합이 가장 큰 부분을 구하는 것이다.
주어지는 배열로는 주가의 증폭값인 change가 주어질 것이다.
이를 구현하는 방법으로, 브루프 포스와 분할정복 두 가지를 알아볼 것이다.
1) Brute force
Brute force는 가능한 모든 계산을 단순히 다 시도해보는 것이다.
이를 의사코드로 작성한다면, 다음과 같다.
// Pseudocode
for buy_date = 0 to 15
for sell_date = buy_date + 1 to 16
find maximum price[sell_date] – price[buy_date]
위와 같이 구현한다면, 이 알고리즘은 $\mathsf{\Theta}(n^2)$ 의 시간복잡도를 가진다.
// Maximum subarray problem (4.1)
// Ver1) Brute-fore
import java.io.*;
class MaximumSubarrayBruteForce {
public static void main(String[] args){
int[] stockPriceChanges = {0, 13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7};
int max = Integer.MIN_VALUE, maxi = 1, maxj = 1;
for (int i = 1; i < stockPriceChanges.length; i++) {
int profit = 0;
for (int j = i; j < stockPriceChanges.length; j++) {
profit += stockPriceChanges[j];
if (profit > max) {
maxi = i;
maxj = j;
max = profit;
}
}
}
System.out.printf("[%d] : %d\\n", maxi, stockPriceChanges[maxi]);
System.out.printf("[%d] : %d\\n", maxj, stockPriceChanges[maxj]);
}
}
2) Divide and conquer
분할정복을 사용한 방법이다.
다음과 같은 전략을 사용하여 구현할 것이다.
- Divide : MaxSubarray(A[low...high]) 를 MaxSubarray(A[low...mid]) 과 MaxSubarray(A[mid+1...high]) 으로 나눈다.
- Conquer : MaxSubarray(A[low...mid]) 과 MaxSubarray(A[mid+1...high]) 를 재귀적으로 푼다.
- Combine : MaxSubarray(A[low...mid]) ,MaxSubarray(A[mid+1...high]), MaxCrossingSubarray(A[low...high]) 중에서 최대값을 고른다.
분할정복을 사용할 경우, 주의해야 할 사항이 있다.
위의 그림에서 (a)를 살펴보자.
mid로 절반을 나눠 계산할 경우, A[low .. mid]와 A[mid + 1 .. high]에 속하는 최대 부분 배열을 찾을 수 있다.
그러나 (b)와 같이 최대 부분 배열이 mid를 포함하는 경우를 찾지 못한다.
그래서 [Combine] 단계에 mid를 포함하는 MaxCrossingSubarray(A[low .. high]) 가 필요하다.
이에 대한 의사코드는 아래와 같다.
// Pseudocode
FIND-MAXIMUM-SUBARRAY(A, low, high)
if high == low
return (low, high, A[low])
else
mid = (low + high) / 2
(leftLow, leftHigh, leftSum) = FIND-MAXIMUM-SUBARRAY(A, low, mid)
(rightLow, rightHigh, rightSum) = FIND-MAXIMUM-SUBARRAY(A, mid + 1, high)
(crossLow, crossHigh, crossSum) = FIND-MAXIMUM-SUBARRAY(A, low, mid, high)
if leftSum >= rightSum and leftSum >= crossSum
return (leftLow, leftHigh, leftSum)
eleif rightSum >= leftSum and rightSum >= crossSum
return (rightLow, rightHigh, rightSum)
else return (crossLow, crossHigh, crossSum)
FIND-MAX-CROSSING-SUBARRAY(A, low, mid, high)
leftSum = -∞ // infinity
sum = 0
for i = mid downto low
sum += A[i]
if sum > leftSum
leftSum = sum
maxLeft = i
rightSum = -∞ // infinity
sum = 0
for j = mid downto low
sum += A[j]
if sum > rightSum
rightSum = sum
maxRight = j
return (maxLeft, maxRight, leftSum + rightSum)
이 방법의 시간 복잡도를 구해보자.
$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad if \ n = 1 \\ 2T(n/2) + \mathsf{\Theta}(n) \qquad if \ n > 1 \end{cases} $$
이기 때문에 마스터 방법2 에 의해서 $T(n) = \mathsf{\Theta}(n \ lg \ n)$ 이다.
// Maximum subarray problem (4.1)
// Ver2) Divde and conquer
import java.io.*;
class Subarray {
public int low, high, sum;
Subarray(int low, int high, int sum) {
this.low = low;
this.high = high;
this.sum = sum;
}
}
class MaximumSubarrayDivideAndConquer {
public static Subarray findMaximumSubarray(int[] A, int low, int high) {
if (high == low) {
return new Subarray(low, high, A[low]);
} else {
int mid = (low + high) / 2;
Subarray left = findMaximumSubarray(A, low, mid);
Subarray right = findMaximumSubarray(A, mid + 1, high);
Subarray cross = findMAxCrossingSubarray(A, low, mid, high);
if ((left.sum >= right.sum) && (left.sum >= cross.sum)) return left;
else if ((right.sum >= left.sum) && (right.sum >= cross.sum)) return right;
else return cross;
}
}
public static Subarray findMAxCrossingSubarray(int[] A, int low, int mid, int high) {
int leftSum = Integer.MIN_VALUE, rightSum = Integer.MIN_VALUE, sum = 0;
int maxLeft = 0, maxRight = 0;
for (int i = mid; i > low; i--) {
sum += A[i];
if (sum > leftSum) {
leftSum = sum;
maxLeft = i;
}
}
sum = 0;
for (int i = mid + 1; i < high; i++) {
sum += A[i];
if (sum > rightSum) {
rightSum = sum;
maxRight = i;
}
}
return new Subarray(maxLeft, maxRight, leftSum + rightSum);
}
public static void main(String[] args){
int[] stockPriceChanges = {0, 13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7};
Subarray result = findMaximumSubarray(stockPriceChanges, 1, stockPriceChanges.length - 1);
System.out.printf("%d %d %d", result.low, result.high, result.sum);
}
}
Matrix multiplication
다음으로는 N * N 행렬을 곱하는 알고리즘이다.
총 세 가지로 단순 반복문으로 정사각 행렬을 곱하는 방법, 분할정복을 통한 방법, 스트라센(Strassen) 알고리즘을 통한 방법이다.
1) Iterative
먼저, 단순 반복문을 통한 방법이다.
아래의 의사코드를 확인하면, 1부터 N까지 순회하는 반복문 3개가 중첩되어 있다.
그래서 이 방법은 $O(n^3)$ 의 시간복잡도를 지닌다.
// Pseudocode
squareMatrixMultiply(A, B)
n = A.rows
let C be a new n * n matrix
for i = 1 to n
for j = 1 to n
Cij = 0
for k = 1 to n
Cij = Cij + Aik * Bkj
위의 의사코드를 아래와 같이 작성할 수 있다.
// Matrix multiplication
// Ver 1) Brute force
import java.util.Arrays;
class squareMatrixMultiply {
public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
// Cij = 0; // Java에서는 자동으로 0으로 초기화되기 때문에 생략
for (int k = 0; k < n; k++) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
return C;
}
public static void main(String[] args){
int[][] A = new int[4][4], B = new int[4][4];
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < A.length; j++) {
A[i][j] = (i * 4) + j + 1;
B[i][j] = (i * 4) + j + 11;
}
}
System.out.println(Arrays.deepToString(A));
System.out.println(Arrays.deepToString(B));
System.out.println(Arrays.deepToString(squareMatrixMultiply(A, B)));
}
}
2) Divide and conquer
다음은 분할정복을 통한 방법이다.
- Divide : A,B,C 행렬을 1⁄4 씩 나눈다.
- Conquer : A11 * B11 , A12 * B21 ,A11 * B12 , A12 * B22 ,A21 * B11 , A22 * B21 ,A21 * B12 , A22 * B22를 계산한다.
- Combine :
C11 = A11 * B11 + A12 * B21
C12 = A11 * B12 + A12 * B21
C21 = A21 * B11 + A22 * B21
C22 = A21 * B12 + A22 * B22 을 계산한다.
// Pseudocode
squareMatrixMultiplyRecursive(A, B)
n = A.rows
let C be a new n * n matrix
if n == 1
c11 = a11 * b11
else
c11 = squareMatrixMultiplyRecursive(A11, B11) + squareMatrixMultiplyRecursive(A12, B21)
c12 = squareMatrixMultiplyRecursive(A11, B12) + squareMatrixMultiplyRecursive(A12, B22)
c21 = squareMatrixMultiplyRecursive(A21, B11) + squareMatrixMultiplyRecursive(A22, B21)
c22 = squareMatrixMultiplyRecursive(A21, B12) + squareMatrixMultiplyRecursive(A22, B22)
return C
위 의사코드의 시간복잡도를 구해보자.
$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad \ if \ n = 1 \\ 8T(n/2) + \mathsf{\Theta}(n^2) \qquad if \ n > 1 \end{cases} $$
이기 때문에 마스터 방법 1에 의해서 $T(n) = \mathsf{\Theta}(n^3)$ 이다.
이렇게 어렵게 구현해야 하는데, 단순히 Brute-force를 사용한 방법과 차이가 없는 것처럼 보인다.
맞다.
이 알고리즘 방식은 메모리 계층 구조가 없는 모델에서는 장점이 없다.
그러나 cache나 virtual memory와 같이 메모리 계층 구조에서는 tiling(blocking) 이라고 불리는 기법에 의해 효과가 있다.
어찌되었든, 코드로 한 번 작성해보자.
// Matrix multiplication
// Ver 2) Divide and Conquer
import java.util.Arrays;
class SquareMatrixMultiplyRecursive {
public static int[][] add(int[][] C, int[][] A, int[][] B, int rowC, int colC) {
int n = A.length;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
C[i + rowC][j + colC] = A[i][j] + B[i][j];
}
}
return C;
}
public static int[][] squareMatrixMultiplyRecursive(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
int[][] C = new int[size][size];
if (size == 1) {
C[0][0] = A[rowA][colA] * B[rowB][colB];
} else {
int newSize = size / 2;
// C11 = squareMatrixMultiplyRecursive(A11, B11) + squareMatrixMultiplyRecursive(A12, B21)
add(C, squareMatrixMultiplyRecursive(A, rowA, colA, B, rowB, colB, newSize), squareMatrixMultiplyRecursive(A, rowA, colA + newSize, B, rowB + newSize, colB, newSize), 0, 0);
// C12 = squareMatrixMultiplyRecursive(A11, B12) + squareMatrixMultiplyRecursive(A12, B22)
add(C, squareMatrixMultiplyRecursive(A, rowA, colA, B, rowB, colB + newSize, newSize), squareMatrixMultiplyRecursive(A, rowA, colA + newSize, B, rowB + newSize, colB + newSize, newSize), 0, newSize);
// C21 = squareMatrixMultiplyRecursive(A21, B11) + squareMatrixMultiplyRecursive(A22, B21)
add(C, squareMatrixMultiplyRecursive(A, rowA + newSize, colA, B, rowB, colB, newSize), squareMatrixMultiplyRecursive(A, rowA + newSize, colA + newSize, B, rowB + newSize, colB , newSize), newSize, 0);
// C22 = squareMatrixMultiplyRecursive(A21, B12) + squareMatrixMultiplyRecursive(A22, B22)
add(C, squareMatrixMultiplyRecursive(A, rowA + newSize, colA, B, rowB, colB + newSize, newSize), squareMatrixMultiplyRecursive(A, rowA + newSize, colA + newSize, B, rowB + newSize, colB + newSize, newSize), newSize, newSize);
}
return C;
}
public static void main(String[] args){
int[][] A = new int[4][4], B = new int[4][4];
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < A.length; j++) {
A[i][j] = (i * 4) + j + 1;
B[i][j] = (i * 4) + j + 11;
}
}
System.out.println(Arrays.deepToString(A));
System.out.println(Arrays.deepToString(B));
System.out.println(Arrays.deepToString(squareMatrixMultiplyRecursive(A, 0, 0, B, 0, 0, A.length)));
}
}
3) Strassen
다음은 스트라센 알고리즘을 통한 방법이다.
- Divide : A,B,C 행렬을 1⁄4 씩 나눈다.
- Conquer : S1~S10 , P1~P7 을 계산한다.
- Combine :
C11 = P5 + P4 - P2 + P6
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7 을 계산한다.
스트라센에 대한 자세한 설명은 다음 링크를 참고하자.
의사코드는 아래와 같다.
// Pseudocode
squareMatrixMultiplyStrassen(A, B)
n = A.rows
let C be a new n * n matrix
if n == 1
C11 = A11 * B11
else
let S1, S2, ..., and S10 be 10 new n/2 * n/2 matrices
let P1, P2, ..., and P7 be 7 new n/2 * n/2 matrices
// S(sum) matrices
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
// P(product) matrices
// P1 = A11 * S1 = A11 * B12 - A11 * B22
P1 = SQUARE-MATRIX-MULTIPLY-STRASSEN(A11, S1)
// P2 = S2 * B22 = A11 * B22 + A12 * B22
P2 = SQUARE-MATRIX-MULTIPLY-STRASSEN(S2, B22)
// P3 = S3 * B11 = A21 * B11 + A22 * B11
P3 = SQUARE-MATRIX-MULTIPLY-STRASSEN(S3, B11)
// P4 = A22 * S4 = A22 * B21 - A22 * B11
P4 = SQUARE-MATRIX-MULTIPLY-STRASSEN(A22, S4)
// P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
P5 = SQUARE-MATRIX-MULTIPLY-STRASSEN(S5, S6)
// P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
P6 = SQUARE-MATRIX-MULTIPLY-STRASSEN(S7, S8)
// P7 = S9 * S10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
P7 = SQUARE-MATRIX-MULTIPLY-STRASSEN(S9, S10)
// final sub matrices
C11 = P4 + P5 + P6 - P2
C12 = P1 + P2
C21 = P3 + P4
C22 = P1 + P5 - P3 - P7
return C
위 의사코드의 시간복잡도를 구해보자.
$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad \ if \ n = 1 \\ 7T(n/2) + \mathsf{\Theta}(n^2) \qquad if \ n > 1 \end{cases} $$
이기 때문에 마스터 방법에 의해서 $T(n) = \mathsf{\Theta}(n^{lg \ 7})$ 이고, 이는 $T(n) = O(n^{2.81}) \ < \ \mathsf{\Theta}(n^3)$ 이다.
앞선 두 알고리즘보다 빠르다.
코드로 한 번 작성해보자.
// Matrix multiplication
// Ver 3) Strassen
import java.util.Arrays;
public class SquareMatrixMultiplyStrassen {
public static int[][] squareMatrixMultiplyStrassen(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
int[][] C = new int[size][size];
if (size == 1) {
C[0][0] = A[rowA][colA] * B[rowB][colB];
} else {
int newSize = size / 2;
int[][] S1 = sub(B, rowB, colB + newSize, B, rowB + newSize, colB + newSize, newSize);
int[][] S2 = add(A, rowA, colA, A, rowA, colA + newSize, newSize);
int[][] S3 = add(A, rowA + newSize, colA, A, rowA + newSize, colA + newSize, newSize);
int[][] S4 = sub(B, rowB + newSize, colB, B, rowB, colB, newSize);
int[][] S5 = add(A, rowA, colA, A, rowA + newSize, colA + newSize, newSize);
int[][] S6 = add(B, rowB, colB, B, rowB + newSize, colB + newSize, newSize);
int[][] S7 = sub(A, rowA, colA + newSize, A, rowA + newSize, colA + newSize, newSize);
int[][] S8 = add(B, rowB + newSize, colB, B, rowB + newSize, colB + newSize, newSize);
int[][] S9 = sub(A, rowA, colA, A, rowA + newSize, colA, newSize);
int[][] S10 = add(B, rowB, colB, B, rowB, colB + newSize, newSize);
int[][] P1 = squareMatrixMultiplyStrassen(A, rowA, colA, S1, 0, 0, newSize);
int[][] P2 = squareMatrixMultiplyStrassen(S2, 0, 0, B,rowB + newSize, colB + newSize, newSize);
int[][] P3 = squareMatrixMultiplyStrassen(S3, 0, 0, B, rowB, colB, newSize);
int[][] P4 = squareMatrixMultiplyStrassen(A, rowA + newSize, colA + newSize, S4, 0, 0, newSize);
int[][] P5 = squareMatrixMultiplyStrassen(S5, 0, 0, S6, 0, 0, newSize);
int[][] P6 = squareMatrixMultiplyStrassen(S7, 0, 0, S8, 0, 0, newSize);
int[][] P7 = squareMatrixMultiplyStrassen(S9, 0, 0, S10, 0, 0, newSize);
int[][] C1 = add(sub(add(P5, P4), P2), P6);
int[][] C2 = add(P1, P2);
int[][] C3 = add(P3, P4);
int[][] C4 = sub(sub(add(P5, P1), P3), P7);
join(C1, C, 0, 0);
join(C2, C, 0, newSize);
join(C3, C, newSize, 0);
join(C4, C, newSize, newSize);
}
return C;
}
private static void join(int[][] C1, int[][] C, int row, int col) {
int size = C1.length;
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
C[i + row][j + col] = C1[i][j];
}
}
}
private static int[][] add(int[][] A, int[][] B) {
int[][] C = new int[A.length][B.length];
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B.length; j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
private static int[][] add(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
int[][] C = new int[size][size];
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
C[i][j] = A[rowA + i][colA + j] + B[rowB + i][colB + j];
}
}
return C;
}
private static int[][] sub(int[][] A, int[][] B) {
int[][] C = new int[A.length][B.length];
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B.length; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
private static int[][] sub(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
int[][] C = new int[size][size];
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
C[i][j] = A[rowA + i][colA + j] - B[rowB + i][colB + j];
}
}
return C;
}
public static void main(String[] args) {
int[][] A = new int[4][4], B = new int[4][4];
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < A.length; j++) {
A[i][j] = (i * 4) + j + 1;
B[i][j] = (i * 4) + j + 11;
}
}
System.out.println(Arrays.deepToString(squareMatrixMultiplyStrassen(A, 0, 0, B, 0, 0, A.length)));
}
}
'CS > 알고리즘' 카테고리의 다른 글
[Introduction to Algorithms] 3. 함수의 증가 (0) | 2022.02.12 |
---|---|
[Introduction to Algorithms] 2. 시작하기 (0) | 2022.02.02 |
[Introduction to Algorithms] 1. 알고리즘의 역할 (0) | 2022.02.02 |
유클리드 알고리즘 (Euclid’s algorithm) (0) | 2021.11.03 |