Thursday, May 25, 2017

311. Sparse Matrix Multiplication

Given two sparse matrices A and B, return the result of AB.
You may assume that A's column number is equal to B's row number.
Example:
A = [
  [ 1, 0, 0],
  [-1, 0, 3]
]

B = [
  [ 7, 0, 0 ],
  [ 0, 0, 0 ],
  [ 0, 0, 1 ]
]


     |  1 0 0 |   | 7 0 0 |   |  7 0 0 |
AB = | -1 0 3 | x | 0 0 0 | = | -7 0 3 |
                  | 0 0 1 |



Solution:

Method 1: 

The naive way is to use three nested loops over i, j, and k, accumulating A[i][k] * B[k][j] into result[i][j].
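For reference, here is a minimal sketch of what the naive version might look like. The class name NaiveMultiply is made up for illustration and is not part of the original solution; it assumes both matrices are non-empty and that A's column count equals B's row count, as the problem states.

```java
// Hypothetical baseline for comparison; the class name is an assumption.
class NaiveMultiply {
    // Textbook triple loop: result[i][j] = sum over k of A[i][k] * B[k][j].
    static int[][] multiply(int[][] A, int[][] B) {
        int m = A.length;     // rows of A
        int t = A[0].length;  // columns of A, which equals rows of B
        int n = B[0].length;  // columns of B

        int[][] result = new int[m][n];

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < t; k++) {
                    // Every product is computed, even when A[i][k] is 0.
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
        }

        return result;
    }
}
```

This always performs m * t * n multiplications, no matter how many entries are zero.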

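We can swap the second and third loops (iterate over k before j) without changing the result, because each result[i][j] still accumulates the same products.

With that ordering, if A[i][k] is 0, every product in the innermost loop is 0, so we can skip that loop entirely and save some time.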



Code:


public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
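        int m = A.length;     // rows of A
        int t = A[0].length;  // columns of A, which equals rows of B
        int n = B[0].length;  // columns of B
        
        int[][] result = new int[m][n];
        
        for (int i = 0; i < m; i++) {
            for (int k = 0; k < t; k++) {
                // A zero in A contributes nothing to row i of the result.
                if (A[i][k] == 0) {
                    continue;
                }
                // Add A[i][k] times row k of B into row i of the result.
                for (int j = 0; j < n; j++) {
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
        }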
        
        return result;
    }
}



Method 2:

Similar to Method 1, we want to skip zeros.

In other words, for each row in A, we only care about its non-zero elements.

Therefore, for each row we use a list to store the non-zero elements together with their column indices.

In the multiplication loop, we only multiply these elements.
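To make the list layout concrete, here is a small sketch of compressing a single row into alternating (column, value) entries. The names RowCompression and compressRow are made up for illustration; this is only one possible way to lay out the data.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper; the class and method names are assumptions.
class RowCompression {
    // Flattens the non-zero entries of one row into
    // [col0, val0, col1, val1, ...].
    static List<Integer> compressRow(int[] row) {
        List<Integer> list = new ArrayList<>();
        for (int j = 0; j < row.length; j++) {
            if (row[j] != 0) {
                list.add(j);       // column index of the non-zero entry
                list.add(row[j]);  // the value itself
            }
        }
        return list;
    }
}
```

For the second row of A in the example, [-1, 0, 3], this yields the list [0, -1, 2, 3]: value -1 at column 0 and value 3 at column 2. Storing pairs in a flat list avoids a separate pair class, at the cost of stepping through the list two entries at a time.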



Code:


import java.util.ArrayList;
import java.util.List;

public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        int m = A.length;
        int t = A[0].length;
        int n = B[0].length;
        
        int[][] result = new int[m][n];
        
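        // index.get(i) holds row i of A as [col0, val0, col1, val1, ...],
        // keeping only the non-zero entries.
        List<List<Integer>> index = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            List<Integer> list = new ArrayList<>();
            for (int j = 0; j < t; j++) {
                if (A[i][j] != 0) {
                    list.add(j);
                    list.add(A[i][j]);
                }
            }
            index.add(list);
        }
        
        for (int i = 0; i < m; i++) {
            List<Integer> list = index.get(i);
            // Step through the (column, value) pairs two entries at a time.
            for (int p = 0; p < list.size() - 1; p += 2) {
                int col = list.get(p);
                int val = list.get(p + 1);
                // Add val times row col of B into row i of the result.
                for (int j = 0; j < n; j++) {
                    result[i][j] += val * B[col][j];
                }
            }
        }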
        
        return result;
    }
}