BLAS: behind-the-scenes fast linear algebra kernels in the presence of memory hierarchies
Recall that the BLAS are classified into three categories:

The Level 1 BLAS (BLAS1) operate mostly on vectors (1D arrays); if the vectors have length n, they perform O(n) operations and return either a vector or a scalar.
- e.g., the saxpy operation: y = a*x + y. "saxpy" is an acronym; the S stands for single precision, and the variants are daxpy for double precision, caxpy for complex, and zaxpy for double complex.

The Level 2 BLAS (BLAS2) operate mostly on a matrix (2D array) and a vector (or vectors), returning a matrix or a vector; if the array is n-by-n, O(n^2) operations are performed.
- e.g., dgemv: matrix-vector multiplication y = y + A*x in double precision, where A is m-by-n, x is n-by-1 and y is m-by-1
- rank-one update A = A + y*x', where A is m-by-n, y is m-by-1, x is n-by-1, etc.
- triangular solve: solve y = T*x for x, where T is a triangular matrix

The Level 3 BLAS (BLAS3) operate on pairs or triples of matrices, returning a matrix.
- e.g., dgemm: matrix-matrix multiplication C = C + A*B, where C is m-by-n, A is m-by-k and B is k-by-n
- multiple triangular solve: solve Y = T*X for X, where T is a triangular matrix and X is a rectangular matrix, etc.
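To make the three levels concrete, here is a small Python sketch that calls one representative double-precision kernel per level through SciPy's BLAS wrappers (this assumes NumPy and SciPy are available; the variable names are ours for illustration):

# Sketch: one representative kernel per BLAS level, via scipy.linalg.blas.
import numpy as np
from scipy.linalg import blas

n = 4
x, y = np.random.rand(n), np.random.rand(n)
A, B, C = np.random.rand(n, n), np.random.rand(n, n), np.random.rand(n, n)

# BLAS1: daxpy computes y = a*x + y  (O(n) flops on O(n) data)
y1 = blas.daxpy(x, y.copy(), a=2.0)

# BLAS2: dgemv computes y = alpha*A*x + beta*y  (O(n^2) flops on O(n^2) data)
y2 = blas.dgemv(1.0, A, x, beta=1.0, y=y.copy())

# BLAS3: dgemm computes C = alpha*A*B + beta*C  (O(n^3) flops on O(n^2) data)
C3 = blas.dgemm(1.0, A, B, beta=1.0, c=C.copy())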
History: BLAS1 ~ early 1970s; BLAS2 ~ mid 1980s; BLAS3 ~ 1990.
Why distinguish? Performance!
[Figure: performance in Mflops of the BLAS on the IBM RS/6000-590 versus matrix or vector dimension (peak machine speed 266 Mflops); BLAS3 is on top (fastest), BLAS2 in the middle, BLAS1 at the bottom.]
A simple memory model

Two levels of memory: "fast" (e.g., L2 cache) and "slow" (e.g., RAM).
- m = number of "slow" memory references
- f = number of floating point operations
- q = f/m = average flops per slow memory reference
A simple memory model

Operation | m          | Justification for m                                       | f     | q
----------+------------+-----------------------------------------------------------+-------+-----
saxpy     | 3n         | read each x(i), y(i) once; write y(i) once                | 2n    | 2/3
sgemv     | n^2 + O(n) | read each A(i,j) once, etc.                               | 2n^2  | 2
sgemm     | 4n^2       | read each A(i,j), B(i,j), C(i,j) once; write C(i,j) once  | 2n^3  | n/2

Interpreting q: for each "slow" word read, we can do at most q flops while it sits in "fast" memory; high q is better!
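As a sanity check on the table's arithmetic, a tiny Python sketch (the value n = 1000 is arbitrary; for sgemv we take m = n^2 + 3n, i.e., reading A and x once and reading/writing y):

# Sketch: evaluate (f, m) and the ratio q for each kernel under the model.
n = 1000
kernels = {
    "saxpy": (2 * n,    3 * n),          # f = 2n,   m = 3n
    "sgemv": (2 * n**2, n**2 + 3 * n),   # f = 2n^2, m = n^2 + O(n)
    "sgemm": (2 * n**3, 4 * n**2),       # f = 2n^3, m = 4n^2
}
for name, (f, m) in kernels.items():
    print(name, "q =", f / m)   # saxpy ~ 0.67, sgemv ~ 2, sgemm = n/2 = 500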
A simple memory model: assumptions

- There are just two levels in the hierarchy, fast and slow.
- The small, fast memory has size M words, where M << n^2, so we can only fit a small part of an entire n-by-n matrix, but M >= 4n, so we can fit several whole rows or columns.
- Each word is read from slow memory individually (in practice, larger groups of words are read, such as cache lines or memory pages, but this doesn't change the basic analysis).
- We have complete control over which words are transferred between the two levels.

The last is a best-case assumption, since the hardware (cache or virtual memory system) often makes this decision for us. In parallel computing, however, when the two levels are local processor memory and remote processor memory, we often have explicit control, whether we want it or not.
Algorithm 1: Unblocked matrix multiply

for i = 1 to n
   {Read row i of A into fast memory}
   for j = 1 to n
      {Read C(i,j) into fast memory}
      {Read column j of B into fast memory}
      for k = 1 to n
         C(i,j) = C(i,j) + A(i,k)*B(k,j)
      end for
      {Write C(i,j) back to slow memory}
   end for
end for
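For reference, a direct NumPy translation of Algorithm 1 (the function name is ours for illustration; in practice one would simply call C += A @ B and let a tuned BLAS do the work):

import numpy as np

def matmul_unblocked(A, B, C):
    # Triple loop mirroring the pseudocode above; C is updated in place.
    n = A.shape[0]
    for i in range(n):            # row i of A is reused across the inner loops
        for j in range(n):        # column j of B is re-fetched for every i
            s = C[i, j]
            for k in range(n):
                s += A[i, k] * B[k, j]
            C[i, j] = s           # write C(i,j) back once
    return C

n = 32
A, B = np.random.rand(n, n), np.random.rand(n, n)
assert np.allclose(matmul_unblocked(A, B, np.zeros((n, n))), A @ B)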
Algorithm 1: Unblocked matrix multiply, analysis

m = # slow memory references
  =  n^3    read each column of B n times (once per i)
  +  n^2    read each row of A once per i, keeping it in fast memory during the execution of the two inner loops
  + 2n^2    read/write each entry of C once
  = n^3 + 3n^2

Thus q = f/m = (2n^3)/(n^3 + 3n^2) ~ 2 << n/2 !
Algorithm 2: Column blocked matrix multiply

Consider the matrix C = [C1, C2, ..., CN] as a set of N blocks, each consisting of n/N complete columns; same for B.

for j = 1 to N
   {Read Bj into fast memory}
   {Read Cj into fast memory}
   for k = 1 to n
      {Read column k of A into fast memory}
      Cj = Cj + A(:,k) * Bj(k,:)   ... rank-1 update of Cj
   end for
   {Write Cj back to slow memory}
end for
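A NumPy sketch of Algorithm 2, assuming for simplicity that N divides n (the function name and the use of np.outer for the rank-1 update are our choices):

import numpy as np

def matmul_column_blocked(A, B, C, N):
    n = A.shape[0]
    w = n // N                          # block width (assumes N divides n)
    for j in range(N):
        Bj = B[:, j*w:(j+1)*w]          # read B_j once
        Cj = C[:, j*w:(j+1)*w]          # view into C: read/write C_j once
        for k in range(n):
            # rank-1 update: C_j = C_j + A(:,k) * B_j(k,:)
            Cj += np.outer(A[:, k], Bj[k, :])
    return C

n, N = 32, 4
A, B = np.random.rand(n, n), np.random.rand(n, n)
assert np.allclose(matmul_column_blocked(A, B, np.zeros((n, n)), N), A @ B)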
Algorithm 2: Column blocked matrix multiply, analysis

Assume fast memory is large enough to hold Bj, Cj and one column of A at a time, i.e., M >= 2n^2/N + n, or N >= 2n^2/(M - n) ~ 2n^2/M.

m = # slow memory references
  =   n^2    read each Bj once
  + N*n^2    read each column of A N times (once per block j)
  + 2*n^2    read/write each Cj once
  = (N+3)*n^2

Thus q = f/m = (2n^3)/((N+3)*n^2) ~ 2n/N, closer to n/2 ! (in fact, since N ~ 2n^2/M, q ~ M/n)
Algorithm 3: Square blocked matrix multiply (or 2D blocked)

Consider C to be an N-by-N matrix of n/N-by-n/N sub-blocks Cij; same for A and B.

for i = 1 to N
   for j = 1 to N
      {Read Cij into fast memory}
      for k = 1 to N
         {Read Aik into fast memory}
         {Read Bkj into fast memory}
         Cij = Cij + Aik * Bkj
      end for
      {Write Cij back to slow memory}
   end for
end for
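A NumPy sketch of Algorithm 3, again assuming N divides n (here the sub-block product Aik @ Bkj stands in for what would be the tuned inner kernel in a real implementation):

import numpy as np

def matmul_square_blocked(A, B, C, N):
    n = A.shape[0]
    b = n // N                                   # sub-block size (N divides n)
    for i in range(N):
        for j in range(N):
            Cij = C[i*b:(i+1)*b, j*b:(j+1)*b]    # view into C: read/write once
            for k in range(N):
                Aik = A[i*b:(i+1)*b, k*b:(k+1)*b]
                Bkj = B[k*b:(k+1)*b, j*b:(j+1)*b]
                Cij += Aik @ Bkj                 # sub-block multiply-add
    return C

n, N = 32, 4
A, B = np.random.rand(n, n), np.random.rand(n, n)
assert np.allclose(matmul_square_blocked(A, B, np.zeros((n, n)), N), A @ B)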
Algorithm 3: Square blocked matrix multiply, analysis

Assume fast memory is large enough to hold the 3 sub-blocks Cij, Aik and Bkj at a time, i.e., M >= 3*(n/N)^2, or N >= sqrt(3/M)*n.

m = # slow memory references
  = N*n^2    read each Bkj N times (N^3 block reads of (n/N)^2 words each)
  + N*n^2    read each Aik N times (likewise)
  + 2*n^2    read/write each Cij once
  = (2N+2)*n^2

Thus q = f/m = (2n^3)/((2N+2)*n^2) ~ n/N ~ sqrt(M/3), which is "optimal": by the Hong-Kung lower bound, no reordering of classical matrix multiply can achieve q better than O(sqrt(M)).
The Problem: MM

ATLAS uses this classic (blocked) matrix multiply. For square matrices of size n-by-n, the algorithm takes O(n^3) flops, and achieves 80-90% of peak performance. Can we do better than O(n^3) in flop count? Yes, e.g., Strassen's algorithm (for large problems).
Matrix Multiplication (blocked)

Partition each matrix into 2-by-2 blocks:

   [C11 C12]   [A11 A12]   [B11 B12]
   [C21 C22] = [A21 A22] * [B21 B22]

so that

   C11 = A11*B11 + A12*B21
   C12 = A11*B12 + A12*B22
   C21 = A21*B11 + A22*B21
   C22 = A21*B12 + A22*B22
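A quick numerical check of the block identity, as a NumPy sketch (the split point h = n/2 assumes even n):

import numpy as np

n = 8; h = n // 2
A, B = np.random.rand(n, n), np.random.rand(n, n)
A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
# Assemble C block by block from the four formulas above.
C = np.block([[A11@B11 + A12@B21, A11@B12 + A12@B22],
              [A21@B11 + A22@B21, A21@B12 + A22@B22]])
assert np.allclose(C, A @ B)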
Strassen’s formula

Partition A, B, C into 2-by-2 blocks as above, form the seven products

   M1 = (A11 + A22) * (B11 + B22)
   M2 = (A21 + A22) * B11
   M3 = A11 * (B12 - B22)
   M4 = A22 * (B21 - B11)
   M5 = (A11 + A12) * B22
   M6 = (A21 - A11) * (B11 + B12)
   M7 = (A12 - A22) * (B21 + B22)

and assemble

   C11 = M1 + M4 - M5 + M7
   C12 = M3 + M5
   C21 = M2 + M4
   C22 = M1 - M2 + M3 + M6

Note: 7 matrix multiplications instead of 8 !
Strassen’s formula, continued

Applying the formula recursively to the sub-blocks gives the flop count

   S(n) = 7*S(n/2)      ... the cost of the 7 recursive calls
        + 18*(n/2)^2    ... 18 n/2-by-n/2 matrix additions

Solving the recurrence gives S(n) = O(n^log2(7)) ~ n^2.81, while traditional multiply gives 2n^3. In fact, the multiplications are what matter: the additions contribute only lower-order terms, so the exponent is log2(7) ~ 2.81.

Caveat: the Strassen algorithm is only weakly stable!
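A minimal recursive NumPy sketch of Strassen's method (it assumes n is a power of two; the cutoff below which it falls back to classical multiply is our arbitrary choice, and real implementations tune this crossover carefully):

import numpy as np

def strassen(A, B, cutoff=64):
    n = A.shape[0]
    if n <= cutoff:
        return A @ B                 # classical multiply for small blocks
    h = n // 2
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    # The 7 recursive products from Strassen's formula.
    M1 = strassen(A11 + A22, B11 + B22, cutoff)
    M2 = strassen(A21 + A22, B11, cutoff)
    M3 = strassen(A11, B12 - B22, cutoff)
    M4 = strassen(A22, B21 - B11, cutoff)
    M5 = strassen(A11 + A12, B22, cutoff)
    M6 = strassen(A21 - A11, B11 + B12, cutoff)
    M7 = strassen(A12 - A22, B21 + B22, cutoff)
    # Assemble the four blocks of C.
    return np.block([[M1 + M4 - M5 + M7, M3 + M5],
                     [M2 + M4,           M1 - M2 + M3 + M6]])

n = 256
A, B = np.random.rand(n, n), np.random.rand(n, n)
assert np.allclose(strassen(A, B), A @ B)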