본문 바로가기

CS/알고리즘

[Introduction to Algorithms] 4. 분할정복

 

이번 장은 재귀를 통한 분할정복(Divide and conquer)을 다룬다.

재귀와 분할정복은 다른 포스팅에서 간략하게 다룬 적이 있다.

 

재귀 : https://codecpr.tistory.com/4

 

10/29 : 자료구조 Recursion

교재 : Data Structures and Abstraction with Java 챕터 : Chap 7 - Recursion Recursion : 문제를 풀기 위해서 문제를 작게 분리했는데, 반복적으로 이뤄지는 작업이라면 Recursion을 이용할 수 있다. Recursion..

codecpr.tistory.com

분할정복 [2.3 알고리즘의 설계 참고] : https://codecpr.tistory.com/24?category=516456 

 

[Introduction to Algorithms] 2. 시작하기

2. 시작하기 많은 알고리즘에서는 루프를 사용하여 문제를 해결하곤 한다. 우리는 코드가 문제없이 작성 되었다면, 그 루프가 정상적으로 작동하여 문제를 해결할 것을 알고 있다. 그러나 이를

codecpr.tistory.com

 

Maximum subarray problem

 

 

가장 먼저 다룰 알고리즘은 최대 부분 배열 문제(Maximum-subarray problem)이다.

최대 부분 배열 문제는 배열을 입력으로 받아 가장 큰 합을 가지는 연속 부분 배열을 찾는다.

 

이를 활용한 문제는 다음과 같다.

아래와 같은 향후 17일간의 주가 그래프가 주어졌다고 하자.

 

우리가 풀 문제는 "언제 사서, 언제 팔아야 이득일까?" 이다.

즉, 특정한 위치에서 특정한 위치까지 정해졌을 때, 그 안의 모든 값들의 합이 가장 큰 부분을 구하는 것이다.

주어지는 배열로는 주가의 증폭값인 change가 주어질 것이다.

 

이를 구현하는 방법으로, 브루프 포스와 분할정복 두 가지를 알아볼 것이다.

 

1) Brute force

Brute force는 가능한 모든 계산을 단순히 다 시도해보는 것이다.

이를 의사코드로 작성한다면, 다음과 같다.

// Pseudocode
for buy_date = 0 to 15
	for sell_date = buy_date + 1 to 16
		find maximum price[sell_date] – price[buy_date]

 

위와 같이 구현한다면, 이 알고리즘은 $\mathsf{\Theta}(n^2)$ 의 시간복잡도를 가진다.

 

// Maximum subarray problem (4.1)
// Ver1) Brute-fore

import java.io.*;

class MaximumSubarrayBruteForce {
    public static void main(String[] args){
        int[] stockPriceChanges = {0, 13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7};
        int max = Integer.MIN_VALUE, maxi = 1, maxj = 1;

        for (int i = 1; i < stockPriceChanges.length; i++) {
            int profit = 0;
            for (int j = i; j < stockPriceChanges.length; j++) {
                profit += stockPriceChanges[j];
                if (profit > max) {
                    maxi = i;
                    maxj = j;
                    max = profit;
                }
            }
        }

        System.out.printf("[%d] : %d\\n", maxi, stockPriceChanges[maxi]);
        System.out.printf("[%d] : %d\\n", maxj, stockPriceChanges[maxj]);
    }
}

 

2) Divide and conquer

분할정복을 사용한 방법이다.

다음과 같은 전략을 사용하여 구현할 것이다.

 
  1. Divide : MaxSubarray(A[low...high]) 를 MaxSubarray(A[low...mid]) 과 MaxSubarray(A[mid+1...high]) 으로 나눈다.
  2. Conquer : MaxSubarray(A[low...mid]) 과 MaxSubarray(A[mid+1...high]) 를 재귀적으로 푼다.
  3. Combine : MaxSubarray(A[low...mid]) ,MaxSubarray(A[mid+1...high]), MaxCrossingSubarray(A[low...high]) 중에서 최대값을 고른다.

분할정복을 사용할 경우, 주의해야 할 사항이 있다.

 

 

위의 그림에서 (a)를 살펴보자.

mid로 절반을 나눠 계산할 경우, A[low .. mid]와 A[mid + 1 .. high]에 속하는 최대 부분 배열을 찾을 수 있다.

그러나 (b)와 같이 최대 부분 배열이 mid를 포함하는 경우를 찾지 못한다.

그래서 [Combine] 단계에 mid를 포함하는 MaxCrossingSubarray(A[low .. high]) 가 필요하다.

 

이에 대한 의사코드는 아래와 같다.

 

// Pseudocode
FIND-MAXIMUM-SUBARRAY(A, low, high)
	if high == low
		return (low, high, A[low])
	else
		mid = (low + high) / 2
		(leftLow, leftHigh, leftSum) = FIND-MAXIMUM-SUBARRAY(A, low, mid)
		(rightLow, rightHigh, rightSum) = FIND-MAXIMUM-SUBARRAY(A, mid + 1, high)
		(crossLow, crossHigh, crossSum) = FIND-MAXIMUM-SUBARRAY(A, low, mid, high)
	if leftSum >= rightSum and leftSum >= crossSum
		return (leftLow, leftHigh, leftSum)
	eleif rightSum >= leftSum and rightSum >= crossSum 
		return (rightLow, rightHigh, rightSum)
	else return (crossLow, crossHigh, crossSum)

FIND-MAX-CROSSING-SUBARRAY(A, low, mid, high)
	leftSum = -∞ // infinity
	sum = 0
	for i = mid downto low
		sum += A[i]
		if sum > leftSum
			leftSum = sum
			maxLeft = i
	rightSum = -∞ // infinity
	sum = 0
	for j = mid downto low
		sum += A[j]
		if sum > rightSum
			rightSum = sum
			maxRight = j
	return (maxLeft, maxRight, leftSum + rightSum)

 

이 방법의 시간 복잡도를 구해보자.

 

$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad if \ n = 1 \\ 2T(n/2) + \mathsf{\Theta}(n) \qquad if \ n > 1 \end{cases} $$

 

이기 때문에 마스터 방법2 에 의해서 $T(n) = \mathsf{\Theta}(n \ lg \ n)$ 이다.

 

 

// Maximum subarray problem (4.1)
// Ver2) Divde and conquer

import java.io.*;

class Subarray {
    public int low, high, sum;
    Subarray(int low, int high, int sum) {
        this.low = low;
        this.high = high;
        this.sum = sum;
    }
}

class MaximumSubarrayDivideAndConquer {
    public static Subarray findMaximumSubarray(int[] A, int low, int high) {
        if (high == low) {
            return new Subarray(low, high, A[low]);
        } else {
          int mid = (low + high) / 2;
            Subarray left = findMaximumSubarray(A, low, mid);
            Subarray right = findMaximumSubarray(A, mid + 1, high);
            Subarray cross = findMAxCrossingSubarray(A, low, mid, high);

            if ((left.sum >= right.sum) && (left.sum >= cross.sum)) return left;
            else if ((right.sum >= left.sum) && (right.sum >= cross.sum)) return right;
            else return cross;
        }
    }

    public static Subarray findMAxCrossingSubarray(int[] A, int low, int mid, int high) {
        int leftSum = Integer.MIN_VALUE, rightSum = Integer.MIN_VALUE, sum = 0;
        int maxLeft = 0, maxRight = 0;

        for (int i = mid; i > low; i--) {
            sum += A[i];
            if (sum > leftSum) {
                leftSum = sum;
                maxLeft = i;
            }
        }

        sum = 0;
        for (int i = mid + 1; i < high; i++) {
            sum += A[i];
            if (sum > rightSum) {
                rightSum = sum;
                maxRight = i;
            }
        }

        return new Subarray(maxLeft, maxRight, leftSum + rightSum);
    }

    public static void main(String[] args){
        int[] stockPriceChanges = {0, 13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7};
        Subarray result = findMaximumSubarray(stockPriceChanges, 1, stockPriceChanges.length - 1);
        System.out.printf("%d %d %d", result.low, result.high, result.sum);
    }
}

 

Matrix multiplication

다음으로는 N * N 행렬을 곱하는 알고리즘이다.

총 세 가지로 단순 반복문으로 정사각 행렬을 곱하는 방법, 분할정복을 통한 방법, 스트라센(Strassen) 알고리즘을 통한 방법이다.

 

1) Iterative

먼저, 단순 반복문을 통한 방법이다.

아래의 의사코드를 확인하면, 1부터 N까지 순회하는 반복문 3개가 중첩되어 있다.

그래서 이 방법은 $O(n^3)$ 의 시간복잡도를 지닌다.

 

// Pseudocode
squareMatrixMultiply(A, B)
    n = A.rows
    let C be a new n * n matrix
    for i = 1 to n
        for j = 1 to n
            Cij = 0
            for k = 1 to n
                Cij = Cij + Aik * Bkj

 

위의 의사코드를 아래와 같이 작성할 수 있다.

 

 

// Matrix multiplication
// Ver 1) Brute force

import java.util.Arrays;

class squareMatrixMultiply {
    public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                // Cij = 0; // Java에서는 자동으로 0으로 초기화되기 때문에 생략
                for (int k = 0; k < n; k++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return C;
    }

    public static void main(String[] args){
        int[][] A = new int[4][4], B = new int[4][4];

        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A.length; j++) {
                A[i][j] = (i * 4) + j + 1;
                B[i][j] = (i * 4) + j + 11;
            }
        }

        System.out.println(Arrays.deepToString(A));
        System.out.println(Arrays.deepToString(B));
        System.out.println(Arrays.deepToString(squareMatrixMultiply(A, B)));
    }
}

 

2) Divide and conquer

다음은 분할정복을 통한 방법이다.

  • Divide : A,B,C 행렬을 1⁄4 씩 나눈다.
  • Conquer : A11 * B11 , A12 * B21 ,A11 * B12 , A12 * B22 ,A21 * B11 , A22 * B21 ,A21 * B12 , A22 * B22를 계산한다.
  • Combine :
    C11 = A11 * B11 + A12 * B21
    C12 = A11 * B12 + A12 * B21
    C21 = A21 * B11 + A22 * B21
    C22 = A21 * B12 + A22 * B22 을 계산한다.

 

 
// Pseudocode
squareMatrixMultiplyRecursive(A, B)
    n = A.rows
    let C be a new n * n matrix
    if n == 1
        c11 = a11 * b11
    else
        c11 = squareMatrixMultiplyRecursive(A11, B11) + squareMatrixMultiplyRecursive(A12, B21)
        c12 = squareMatrixMultiplyRecursive(A11, B12) + squareMatrixMultiplyRecursive(A12, B22)
        c21 = squareMatrixMultiplyRecursive(A21, B11) + squareMatrixMultiplyRecursive(A22, B21)
        c22 = squareMatrixMultiplyRecursive(A21, B12) + squareMatrixMultiplyRecursive(A22, B22)

    return C

 

위 의사코드의 시간복잡도를 구해보자.

 

$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad \ if \ n = 1 \\ 8T(n/2) + \mathsf{\Theta}(n^2) \qquad if \ n > 1 \end{cases} $$

 

이기 때문에 마스터 방법 1에 의해서 $T(n) = \mathsf{\Theta}(n^3)$ 이다.

이렇게 어렵게 구현해야 하는데, 단순히 Brute-force를 사용한 방법과 차이가 없는 것처럼 보인다.

 

맞다.

이 알고리즘 방식은 메모리 계층 구조가 없는 모델에서는 장점이 없다.

그러나 cache나 virtual memory와 같이 메모리 계층 구조에서는 tiling(blocking) 이라고 불리는 기법에 의해 효과가 있다.

 

어찌되었든, 코드로 한 번 작성해보자. 

 

// Matrix multiplication
// Ver 2) Divide and Conquer

import java.util.Arrays;

class SquareMatrixMultiplyRecursive {
    public static int[][] add(int[][] C, int[][] A, int[][] B, int rowC, int colC) {
        int n = A.length;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i + rowC][j + colC] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    public static int[][] squareMatrixMultiplyRecursive(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
        int[][] C = new int[size][size];

        if (size == 1) {
            C[0][0] = A[rowA][colA] * B[rowB][colB];
        } else {
            int newSize = size / 2;
            // C11 = squareMatrixMultiplyRecursive(A11, B11) + squareMatrixMultiplyRecursive(A12, B21)
            add(C, squareMatrixMultiplyRecursive(A, rowA, colA, B, rowB, colB, newSize), squareMatrixMultiplyRecursive(A, rowA, colA + newSize, B, rowB + newSize, colB, newSize), 0, 0);
            // C12 = squareMatrixMultiplyRecursive(A11, B12) + squareMatrixMultiplyRecursive(A12, B22)
            add(C, squareMatrixMultiplyRecursive(A, rowA, colA, B, rowB, colB + newSize, newSize), squareMatrixMultiplyRecursive(A, rowA, colA + newSize, B, rowB + newSize, colB + newSize, newSize), 0, newSize);
            // C21 = squareMatrixMultiplyRecursive(A21, B11) + squareMatrixMultiplyRecursive(A22, B21)
            add(C, squareMatrixMultiplyRecursive(A, rowA + newSize, colA, B, rowB, colB, newSize), squareMatrixMultiplyRecursive(A, rowA  + newSize, colA + newSize, B, rowB + newSize, colB , newSize),  newSize, 0);
            // C22 = squareMatrixMultiplyRecursive(A21, B12) + squareMatrixMultiplyRecursive(A22, B22)
            add(C, squareMatrixMultiplyRecursive(A, rowA + newSize, colA, B, rowB, colB + newSize, newSize), squareMatrixMultiplyRecursive(A, rowA + newSize, colA + newSize, B, rowB + newSize, colB + newSize, newSize), newSize, newSize);
        }

        return C;
    }

    public static void main(String[] args){
        int[][] A = new int[4][4], B = new int[4][4];

        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A.length; j++) {
                A[i][j] = (i * 4) + j + 1;
                B[i][j] = (i * 4) + j + 11;
            }
        }

        System.out.println(Arrays.deepToString(A));
        System.out.println(Arrays.deepToString(B));
        System.out.println(Arrays.deepToString(squareMatrixMultiplyRecursive(A, 0, 0, B, 0, 0, A.length)));
    }
}

 

3) Strassen

다음은 스트라센 알고리즘을 통한 방법이다.

  1. Divide : A,B,C 행렬을 1⁄4 씩 나눈다.
  2. Conquer : S1~S10 , P1~P7 을 계산한다.
  3. Combine :
    C11 = P5 + P4 - P2 + P6
     
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7 을 계산한다.

 

스트라센에 대한 자세한 설명은 다음 링크를 참고하자.

https://ko.wikipedia.org/wiki/%EC%8A%88%ED%8A%B8%EB%9D%BC%EC%84%BC_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

 

슈트라센 알고리즘 - 위키백과, 우리 모두의 백과사전

선형대수학에서 슈트라센 알고리즘은 독일의 수학자 폴커 슈트라센(Volker Strassen)이 1969년에 개발한 행렬 곱셈 알고리즘이다. 정의에 따라 n×n 크기의 두 행렬을 곱하면 O(n3)의 시간이 소요되지만

ko.wikipedia.org

 

의사코드는 아래와 같다.

 

// Pseudocode
squareMatrixMultiplyStrassen(A, B)
	n = A.rows
	let C be a new n * n matrix
	if n == 1
		C11 = A11 * B11
	else
		let S1, S2, ..., and S10 be 10 new n/2 * n/2 matrices
		let P1, P2, ..., and P7 be 7 new n/2 * n/2 matrices

		// S(sum) matrices
		S1  = B12 - B22
		S2  = A11 + A12
		S3  = A21 + A22
		S4  = B21 - B11
		S5  = A11 + A22
		S6  = B11 + B22
		S7  = A12 - A22
		S8  = B21 + B22
		S9  = A11 - A21
		S10 = B11 + B12

		// P(product) matrices
		// P1 = A11 * S1 = A11 * B12 - A11 * B22
		P1  = SQUARE-MATRIX-MULTIPLY-STRASSEN(A11, S1)
		// P2 = S2 * B22 = A11 * B22 + A12 * B22
		P2  = SQUARE-MATRIX-MULTIPLY-STRASSEN(S2, B22)
		// P3 = S3 * B11 = A21 * B11 + A22 * B11
		P3  = SQUARE-MATRIX-MULTIPLY-STRASSEN(S3, B11)
		// P4 = A22 * S4 = A22 * B21 - A22 * B11
		P4  = SQUARE-MATRIX-MULTIPLY-STRASSEN(A22, S4)
		// P5 = S5 * S6 = A11 * B11 + A11 * B22 + A22 * B11 + A22 * B22
		P5  = SQUARE-MATRIX-MULTIPLY-STRASSEN(S5, S6)
		// P6 = S7 * S8 = A12 * B21 + A12 * B22 - A22 * B21 - A22 * B22
		P6  = SQUARE-MATRIX-MULTIPLY-STRASSEN(S7, S8)
		// P7 = S9 * S10 = A11 * B11 + A11 * B12 - A21 * B11 - A21 * B12
		P7  = SQUARE-MATRIX-MULTIPLY-STRASSEN(S9, S10)

		// final sub matrices
		C11 = P4 + P5 + P6 - P2
		C12 = P1 + P2
		C21 = P3 + P4
		C22 = P1 + P5 - P3 - P7
	
	return C

 

위 의사코드의 시간복잡도를 구해보자.

 

$$ f(n) = \begin{cases} \mathsf{\Theta}(1) \qquad \qquad \qquad \quad \ if \ n = 1 \\ 7T(n/2) + \mathsf{\Theta}(n^2) \qquad if \ n > 1 \end{cases} $$

 

이기 때문에 마스터 방법에 의해서 $T(n) = \mathsf{\Theta}(n^{lg \ 7})$ 이고, 이는 $T(n) = O(n^{2.81}) \ < \ \mathsf{\Theta}(n^3)$ 이다.

앞선 두 알고리즘보다 빠르다.

 

코드로 한 번 작성해보자.

 

// Matrix multiplication
// Ver 3) Strassen

import java.util.Arrays;

public class SquareMatrixMultiplyStrassen {
    public static int[][] squareMatrixMultiplyStrassen(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
        int[][] C = new int[size][size];
        if (size == 1) {
            C[0][0] = A[rowA][colA] * B[rowB][colB];
        } else {
            int newSize = size / 2;
            int[][] S1 = sub(B, rowB, colB + newSize, B, rowB + newSize, colB + newSize, newSize);
            int[][] S2 = add(A, rowA, colA, A, rowA, colA + newSize, newSize);
            int[][] S3 = add(A, rowA + newSize, colA, A, rowA + newSize, colA + newSize, newSize);
            int[][] S4 = sub(B, rowB + newSize, colB, B, rowB, colB, newSize);
            int[][] S5 = add(A, rowA, colA, A, rowA + newSize, colA + newSize, newSize);
            int[][] S6 = add(B, rowB, colB, B, rowB + newSize, colB + newSize, newSize);
            int[][] S7 = sub(A, rowA, colA + newSize, A, rowA + newSize, colA + newSize, newSize);
            int[][] S8 = add(B, rowB + newSize, colB, B, rowB + newSize, colB + newSize, newSize);
            int[][] S9 = sub(A, rowA, colA, A, rowA + newSize, colA, newSize);
            int[][] S10 = add(B, rowB, colB, B, rowB, colB + newSize, newSize);

            int[][] P1 = squareMatrixMultiplyStrassen(A, rowA, colA, S1, 0, 0, newSize);
            int[][] P2 = squareMatrixMultiplyStrassen(S2, 0, 0, B,rowB + newSize, colB + newSize, newSize);
            int[][] P3 = squareMatrixMultiplyStrassen(S3, 0, 0, B, rowB, colB, newSize);
            int[][] P4 = squareMatrixMultiplyStrassen(A, rowA + newSize, colA + newSize, S4, 0, 0, newSize);
            int[][] P5 = squareMatrixMultiplyStrassen(S5, 0, 0, S6, 0, 0, newSize);
            int[][] P6 = squareMatrixMultiplyStrassen(S7, 0, 0, S8, 0, 0, newSize);
            int[][] P7 = squareMatrixMultiplyStrassen(S9, 0, 0, S10, 0, 0, newSize);

            int[][] C1 = add(sub(add(P5, P4), P2), P6);
            int[][] C2 = add(P1, P2);
            int[][] C3 = add(P3, P4);
            int[][] C4 = sub(sub(add(P5, P1), P3), P7);

            join(C1, C, 0, 0);
            join(C2, C, 0, newSize);
            join(C3, C, newSize, 0);
            join(C4, C, newSize, newSize);
        }
        return C;
    }

    private static void join(int[][] C1, int[][] C, int row, int col) {
        int size = C1.length;
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                C[i + row][j + col] = C1[i][j];
            }
        }
    }

    private static int[][] add(int[][] A, int[][] B) {
        int[][] C = new int[A.length][B.length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < B.length; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    private static int[][] add(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
        int[][] C = new int[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                C[i][j] = A[rowA + i][colA + j] + B[rowB + i][colB + j];
            }
        }
        return C;
    }

    private static int[][] sub(int[][] A, int[][] B) {
        int[][] C = new int[A.length][B.length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < B.length; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }

    private static int[][] sub(int[][] A, int rowA, int colA, int[][] B, int rowB, int colB, int size) {
        int[][] C = new int[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                C[i][j] = A[rowA + i][colA + j] - B[rowB + i][colB + j];
            }
        }
        return C;
    }

    public static void main(String[] args) {
        int[][] A = new int[4][4], B = new int[4][4];

        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A.length; j++) {
                A[i][j] = (i * 4) + j + 1;
                B[i][j] = (i * 4) + j + 11;
            }
        }

        System.out.println(Arrays.deepToString(squareMatrixMultiplyStrassen(A, 0, 0, B, 0, 0, A.length)));
    }
}