循环取矩阵的某行_1.2 震惊! 某大二本科生写的矩阵乘法吊打Mathematica-线性代数库BLAS-矩阵 (上)...
本文是
1. 線性代數庫BLAS?zhuanlan.zhihu.com系列的第二篇, 將講述矩陣類的結構和矩陣基礎運算的AVX2加速算法.
1. 矩陣類的結構
在講述矩陣各種算法之前很有必要詳解一下普通矩陣和各種常用的特殊矩陣 (包括方陣, 三對角矩陣, 對稱矩陣, 上下三角矩陣, 帶狀矩陣, 上下三角帶狀矩陣和稀疏矩陣)的基本數據結構.
為了表明某一矩陣是特殊矩陣, 我創建了一個枚舉類MatType以指示矩陣的種類. 而為了更快速的區分各種帶狀矩陣 (儲存結構和前面的矩陣不一樣) 和稀疏矩陣, 我采用了以下安排方式
enum class MatType {NormalMat,SquareMat,DiagonalMat,SymmetricMat,LMat,UMat,//for band mat, width == heightBandMat,LBandMat,UBandMat,//sparse matSparseMat, };這樣只需要用比較matType和BandMat, SparseMat的大小就能判斷是否是帶狀矩陣或者稀疏矩陣. 而對于向量, 還存在這樣一種分類:
enum class Type {Native = 0,Parasitic = 1,Non32Aligened = 2,//must be Parasitic as well... };Native即原生的, 構造時是在堆上分配了內存的; Parasitic指的是32字節對齊的寄生向量, Non32Aligened則是沒有32字節對齊的寄生向量, 一般是向量的子向量或者矩陣的一行等等. 這些寄生向量在構造和析構時不會進行內存分配和釋放, 避免構造中間向量, 提高性能.
而類mat的成員如下:
double* data; union {unsigned long long width;unsigned long long halfBandWidth;//if is BandMatunsigned long long* rowIndice;//if is SparseMat }; union {unsigned long long height;unsigned long long* colIndice;//if is SparseMat }; union {unsigned long long width4d;unsigned long long elementNum; }; Type type; MatType matType;由于一個矩陣不能同時為普通矩陣, 帶狀矩陣或者稀疏矩陣, 所以可以利用union來節約空間, 并在進行各種矩陣運算的時候根據MatType來使用對應的方式來訪問矩陣.
接下來則是詳解各種矩陣元素的存儲結構.
1. 非帶狀矩陣
即原始的各種矩陣, 最naive的想法是直接按照行優先 (即同一行的元素是連續的)線性鋪開, 但是很快我們就可以發現一個問題: 由于我們分配的內存是4 double對齊的, 所以第一行是一個正常的對齊的向量, 但是若矩陣的列數是一個奇數, 那么4n+1, 4n+2, 4n+3行都不是4 double對齊的向量. 顯然這不是我們想要的, 因為在很多情況下, 計算是圍繞這一個行向量展開的, 倘若3/4的行向量都不是對齊的, 那對性能的影響是很大的. 所以聰明的你肯定想到了這樣的作法: 將行向量的末尾用4 double補齊, 再轉到下一行. 這樣雖然有部分空間沒有利用上, 但是在這個動輒64GB內存的大環境 (自己在做計算物理大作業的時候確實用到過這么多, 即
維電阻網絡的兩點間電阻計算), 相信大家一定不會在意這一點點的"浪費".2. 帶狀矩陣
a. 對于半帶寬為halfBandWidth的n階帶狀矩陣, 計算物理課上提示的是稠密存儲, 即按照
這種模式一個一個的順序存儲 (空位不被儲存). 我們發現這樣存儲對與行向量的4 double對齊非常不友好, 所以同理, 需要在前后加上padding以使每個行向量能夠4 double對齊 (注意這里的對齊指的是按照列號來對齊而不是按照該元素是該行第n個非0元的n來對齊, 例如帶狀矩陣乘向量時, 向量是4 double對齊的, 所以按照矩陣乘法, 矩陣某行的列序號是要對齊的). 按照這種對齊方式, 我們可以計算出每行至少為ceiling4(2*halfBandWidth+4)個元素, 其中ceiling4函數是按4的模向上取整, 例如ceiling4(5)=8等等.
b. 同理, 對于半帶寬為halfBandWidth的n階上/下三角帶狀矩陣, 即(半帶寬為2的6階下三角矩陣)
每行至少為ceiling4(halfBandWidth+4)個元素.
3. 稀疏矩陣
這里采用行優先的方式順序存儲, 即下面的稀疏矩陣
按照{(1, 1,
), (1, 2, ), (1, 6, ), (2, 2, ), ... , (6, 6, )}的方式儲存, 但是這里的(rowIndex, columnIndex, element)不是直接連續排放在一起的, 而是rowIndex放成一個數組, columnIndex放成一個數組, element放成一個數組.2. 普通矩陣的基礎運算
計算物理中, 很多算法是針對某種特殊矩陣的, 需要與普通矩陣的運算分開. 今天這篇只會講普通矩陣的運算, 我們有以下常用的矩陣基礎運算 (剩下的線性代數相關的運算和算法等會放在下篇):
由于上一篇講向量的過程中已經講過了基礎的對應位置的各種算法, 這里就不在重復講述了, 我將把重點放在很多情況下需要深度優化的矩陣乘向量和矩陣乘法, 例如矩陣乘向量在各種迭代法求解線性方程組中是非常重要的操作, 好的優化至關重要.
3. 矩陣乘向量
由于代碼過于冗長 (因為需要考慮到各種邊界上沒對齊問題), 這里只用部分代碼來解釋具體的并行化方法, 并且只講述普通矩陣的優化方法 (其他矩陣如帶狀矩陣和稀疏矩陣并沒有進行特殊優化).
矩陣乘向量
, trivial的想法是將矩陣視為多個行向量, 然后和向量b點乘得到結果c的一個元素, 求和的順序即矩陣行優先 (指的是求和的最內部的循環是同一行).但是顯然這總方式的并行化程度很差, 我們可以從更量化的角度來看這個問題.
在GPU的并行編程中有一個概念叫計算-讀寫比, 即計算操作與讀寫操作的比例, 當這個比例為算力/顯存帶寬時, 計算效率達能到最大化, 而且在很多情況下讀寫是瓶頸. 以剛才的naive方法為例, 假設我們已經使用了向量點乘的AVX2優化, 那么進行一次元運算 (對于一個AVX2寄存器來說, 一次元操作為對4 double的計算)fmadd需要讀寫8 double (矩陣a的4 double和向量b的4 double, 由于結果是寫在另一個寄存器內, 所以不考慮這部分的寫入), 但是只進行了一次對4 double的fmadd操作, 所以計算-讀寫比為0.5.
這個比例顯然沒有達到我們的預期. 首先想到讀取一次向量b的4 double后可以讀取矩陣a的多行的4 double (我選取了4行, 這時候的計算-讀寫比為0.8, 可以自行驗算), 這樣就能重復利用向量b的數據, 達到加速的目的. 同時, 由于單核的AVX2寄存器數量很多 (一般至少有16個), 所以為了增大輸出的口徑 (原理是編譯器會將這些相鄰的計算分配到不同的寄存器以使得多個寄存器能并行地在流水線上工作), 可以將向量b分成多組, 每一組含有多個__m256d數據 (經過測試8個較為合適, 更大的話一般不會更快, 反而會使得剩余的最后一組變得更大, 耗時更久).
constexpr unsigned long long warp = 8;//將輸入向量分組為每組8個__m256d unsigned long long minWidth4((minDim - 1) / 4 + 1); __m256d* aData((__m256d*)data);//矩陣 __m256d* bData((__m256d*)source->data);//輸入向量 __m256d* rData((__m256d*)b.data);//輸出向量 unsigned long long heightFloor4((height >> 2) << 2); unsigned long long widthWarp((minWidth4 / warp) * warp); unsigned long long warpLeftFloor((minDim >> 2) - widthWarp); unsigned long long warpLeftCeiling(minWidth4 - widthWarp); unsigned long long c0(0); for (; c0 < heightFloor4; c0 += 4)//每次計算4行 {__m256d ans[4] = { 0 };__m256d tp[warp];unsigned long long c1(0);for (; c1 < widthWarp; c1 += warp){__m256d* s(aData + minWidth4 * c0 + c1); #pragma unroll(4)for (unsigned long long c2(0); c2 < warp; ++c2)tp[c2] = bData[c1 + c2];for (unsigned long long c2(0); c2 < 4; ++c2, s += minWidth4){ #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){__m256d t = s[c3];ans[c2] = _mm256_fmadd_pd(t, tp[c3], ans[c2]);}}}if (c1 < minWidth4)//處理沒有分組對齊的部分{__m256d* s(aData + minWidth4 * c0 + c1); #pragma unroll(4)for (unsigned long long c2(0); c2 < warpLeftCeiling; ++c2)tp[c2] = bData[c1 + c2];unsigned long long finalWidth(minDim - ((minDim >> 2) << 2));for (unsigned long long c2(finalWidth); c2 < 4; ++c2)tp[warpLeftFloor].m256d_f64[c2] = 0;for (unsigned long long c2(0); c2 < 4; ++c2, s += minWidth4){ #pragma unroll(4)for (unsigned long long c3(0); c3 < warpLeftCeiling; ++c3){__m256d t = s[c3];ans[c2] = _mm256_fmadd_pd(t, tp[c3], ans[c2]);}}}__m256d s;for (unsigned long long c1(0); c1 < 4; ++c1){s.m256d_f64[c1] = ans[c1].m256d_f64[0];s.m256d_f64[c1] += ans[c1].m256d_f64[1];s.m256d_f64[c1] += ans[c1].m256d_f64[2];s.m256d_f64[c1] += ans[c1].m256d_f64[3];}rData[c0 >> 2] = s; } //后面是處理剩余的非4行對齊的部分4. 矩陣乘矩陣
此處是本文的重頭戲, 我才不是標題黨呢, 1024階方陣乘法用Ryzen 3950X實測數據如下 (依次為手寫AVX2優化, Mathematica和matlab的計算時間, 由于CPU頻率和其他負載等因素導致的時間計算不準確, 計算耗時都已經取了多次的最快的那次, 并且都禁用了打印結果并提前分配好儲存答案的內存. 手寫版本的正確性已經通過Mathematica計算與正確答案差的范數驗證過了, 所有計算方式均已開啟多線程優化以利用所有核心和超線程):
可以看出, 手寫的矩陣乘法確實比Mathematica快了4倍有余, 和使用mkl的matlab速度幾乎一致 (當然仍然會被CUDA的cublas爆錘). 那么, 這么快的矩陣乘法是咋寫的呢?
有了上面矩陣乘向量的經驗, 我們明白了一件事情: 想要提高速度, 就要讓一次讀取的數據能盡可能多的被用于計算. 有了這樣的想法, 矩陣乘法的AVX2優化也就呼之欲出:
我們需要從另一個角度來看矩陣乘法, 而非和之前的矩陣乘向量一樣將右矩陣視為多個列向量. 矩陣乘法
, 這時候從這個乘法求和式可以看出, 矩陣a的第i行的第k列 要遍歷地乘以矩陣b的第k行的所有元素, 得到一行向量 (這里i, k是固定的, 所以得到一個以j為列指標的行向量), 對k求和就可以得到最終矩陣c的第i行 (這里i是固定的, 以j為列指標). 到這里并行化的方式已經很明顯了, 就是將 利用之前學過的向量乘常數的算法加速. 但是需要注意的是, 由于k的求和范圍1~n可以很大 (如1024階矩陣就是1024), 這樣就無法將計算結果 的整個行向量儲存在寄存器內, 所以也需要將其分組, 經過測試我選擇了16個__m256d為一組, 這樣每次得到的也就是矩陣c的第i行 的一組. 而由于矩陣a的兩行 等需要乘以b的同一行 , 所以可以利用這一點再同時計算這兩行以提高數據利用程度. 經過了以上優化, 大家可以算一下計算-讀寫比, 以下代碼是讀了(2+64)個double, 由于是在內循環外, 所以在b的行數較大的時候基本可以忽略寫入, 計算了128次fmadd, 所以計算-讀寫比為1.94, 相比傳統的三個循環的算法的讀2算1的0.5高了近3倍, 而且這里除了那2個double不是連續讀取, 剩下所有的讀寫都是連續的, 對于內存來說十分友好.//source是乘法的左矩陣 __m256d* aData((__m256d*)a.data);//乘法的右矩陣 __m256d* bData((__m256d*)b.data);//乘法結果 constexpr unsigned long long warp = 16; unsigned long long aWidth256d(a.width4d / 4); unsigned long long aWidthWarpFloor(aWidth256d / warp * warp); unsigned long long warpLeft(aWidth256d - aWidthWarpFloor); unsigned long long height2Floor(height & (-2)); unsigned long long c0(0); for (; c0 < height2Floor; c0 += 2) {unsigned long long c1(0);for (; c1 < aWidthWarpFloor; c1 += warp){__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * width4d + c2]);__m256d tp1 = _mm256_set1_pd(source->data[(c0 + 1) * width4d + c2]); #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}} #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){bData[c0 * aWidth256d + c1 + c3] = ans0[c3];bData[(c0 + 1) * aWidth256d + c1 + c3] = ans1[c3];}}if (c1 < aWidth256d)//對行向量分組后剩余的部分{__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * width4d + c2]);__m256d tp1 = _mm256_set1_pd(source->data[(c0 + 1) * width4d + c2]);for (unsigned long long c3(0); c3 < warpLeft; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}}for (unsigned long long c3(0); c3 < warpLeft; ++c3){bData[c0 * aWidth256d + c1 + c3] = ans0[c3];bData[(c0 + 1) * aWidth256d + c1 + c3] = ans1[c3];}} } //后面還需計算因每次計算2行導致的可能的剩余的一行至于多線程版本的矩陣乘法, 假設矩陣a有n行, 總共有m個線程, 由于是每2行需要同時計算, 而不同的2行之間是獨立的, 所以可以將連續的floor2(n/m)行分配給同一線程 (floor2指的是對2取模的向下取整, 例floor2(3) = 2), 并讓原線程 (調用矩陣乘法函數的那個線程)來計算最后的那一組和剩下的那行 (如果n是奇數). 由于不想使用全局函數來寫每個線程調用的子函數, 而創建線程的函數_beginthread需要的線程函數不能是帶this指針的類函數, 所以在矩陣乘法函數內嵌一個lambda表達式是再好不過的了 (不用beginthreadex是因為lambda函數是__cdecl的, 而beginthreadex需要的是__stdcall函數作為線程函數). 最終代碼如下:
BLAS::mat& matMultMT(BLAS::mat const& ts, BLAS::mat const& a, BLAS::mat& b) {using namespace BLAS;SYSTEM_INFO systemInfo;GetSystemInfo(&systemInfo);unsigned long long threadNum(systemInfo.dwNumberOfProcessors);::printf("Number of processors: %llun", threadNum);unsigned long long minDim(ts.width > a.height ? a.height : ts.width);if (minDim){bool overflow(ceiling4(a.width, ts.height) > b.width4d * b.height);if (overflow && b.type != Type::Native)return b;mat const* source(&ts);mat r;if (&b == &ts){source = &r;r = ts;}if (overflow)b.reconstruct(a.width, ts.height, false);__m256d* aData((__m256d*)a.data);__m256d* bData((__m256d*)b.data);constexpr unsigned long long warp = 16;unsigned long long aWidth256d(a.width4d / 4);unsigned long long aWidthWarpFloor(aWidth256d / warp * warp);unsigned long long warpLeft(aWidth256d - aWidthWarpFloor);unsigned long long height2Floor(ts.height & (-2));//ptr: A beginning// A ending// B// C beginning// width4d// aWidthWarpFloor// minDim// aWidth256d// warpLeftvoid(*lambda)(void*) = [](void* ptr){void** ptrs((void**)ptr);double* source = ((double*)ptrs[0]);double* sourceEnding((double*)ptrs[1]);__m256d* aData((__m256d*)ptrs[2]);__m256d* bData((__m256d*)ptrs[3]);unsigned long long width4d((unsigned long long)ptrs[4]);unsigned long long aWidthWarpFloor((unsigned long long)ptrs[5]);unsigned long long minDim((unsigned long long)ptrs[6]);unsigned long long aWidth256d((unsigned long long)ptrs[7]);unsigned long long warpLeft((unsigned long long)ptrs[8]);constexpr unsigned long long warp = 16;for (; source < sourceEnding; source += width4d * 2, bData += aWidth256d * 2){unsigned long long c1(0);for (; c1 < aWidthWarpFloor; c1 += warp){__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source[c2]);__m256d tp1 = _mm256_set1_pd(source[width4d + c2]); #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}} #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){bData[c1 + c3] = ans0[c3];bData[aWidth256d + c1 + c3] = ans1[c3];}}if (c1 < aWidth256d){__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source[c2]);__m256d tp1 = _mm256_set1_pd(source[width4d + c2]);for (unsigned long long c3(0); c3 < warpLeft; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}}for (unsigned long long c3(0); c3 < warpLeft; ++c3){bData[c1 + c3] = ans0[c3];bData[aWidth256d + c1 + c3] = ans1[c3];}}}};unsigned long long threadHeight((ts.height / threadNum) & (-2));//floor2unsigned long long rowBeginning;HANDLE* threads(nullptr);void** paras(nullptr);if (threadHeight)//如果每個線程都沒法分配到至少2行, 那就直接用單線程算了{rowBeginning = threadHeight * (threadNum - 1);threads = (HANDLE*)::malloc((threadNum - 1) * sizeof(HANDLE));paras = (void**)::malloc((threadNum - 1) * 9 * sizeof(void*));for (unsigned long long c0(0); c0 < threadNum - 1; ++c0){paras[c0 * 9] = (void*)(source->data + threadHeight * ts.width4d * c0);paras[c0 * 9 + 1] = (void*)(source->data + threadHeight * ts.width4d * (c0 + 1));paras[c0 * 9 + 2] = (void*)(a.data);paras[c0 * 9 + 3] = (void*)(b.data + threadHeight * a.width4d * c0);paras[c0 * 9 + 4] = (void*)(ts.width4d);paras[c0 * 9 + 5] = (void*)(aWidthWarpFloor);paras[c0 * 9 + 6] = (void*)(minDim);paras[c0 * 9 + 7] = (void*)(aWidth256d);paras[c0 * 9 + 8] = (void*)(warpLeft);threads[c0] = (HANDLE)_beginthread(lambda, 0, paras + c0 * 9);}}else rowBeginning = 0;//接下來的部分就是原線程求最后剩下的多行unsigned long long c0(rowBeginning);for (; c0 < height2Floor; c0 += 2){unsigned long long c1(0);for (; c1 < aWidthWarpFloor; c1 += warp){__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * ts.width4d + c2]);__m256d tp1 = _mm256_set1_pd(source->data[(c0 + 1) * ts.width4d + c2]); #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}} #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){bData[c0 * aWidth256d + c1 + c3] = ans0[c3];bData[(c0 + 1) * aWidth256d + c1 + c3] = ans1[c3];}}if (c1 < aWidth256d){__m256d ans0[warp] = { 0 };__m256d ans1[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * ts.width4d + c2]);__m256d tp1 = _mm256_set1_pd(source->data[(c0 + 1) * ts.width4d + c2]);for (unsigned long long c3(0); c3 < warpLeft; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);ans1[c3] = _mm256_fmadd_pd(tp1, b, ans1[c3]);}}for (unsigned long long c3(0); c3 < warpLeft; ++c3){bData[c0 * aWidth256d + c1 + c3] = ans0[c3];bData[(c0 + 1) * aWidth256d + c1 + c3] = ans1[c3];}}}if (c0 < ts.height){unsigned long long c1(0);for (; c1 < aWidthWarpFloor; c1 += warp){__m256d ans0[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * ts.width4d + c2]); #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);}} #pragma unroll(4)for (unsigned long long c3(0); c3 < warp; ++c3)bData[c0 * aWidth256d + c1 + c3] = ans0[c3];}if (c1 < aWidth256d){__m256d ans0[warp] = { 0 };for (unsigned long long c2(0); c2 < minDim; ++c2){__m256d tp0 = _mm256_set1_pd(source->data[c0 * ts.width4d + c2]);for (unsigned long long c3(0); c3 < warpLeft; ++c3){__m256d b = aData[aWidth256d * c2 + c1 + c3];ans0[c3] = _mm256_fmadd_pd(tp0, b, ans0[c3]);}}for (unsigned long long c3(0); c3 < warpLeft; ++c3)bData[c0 * aWidth256d + c1 + c3] = ans0[c3];}}if (threadHeight){WaitForMultipleObjects(threadNum - 1, threads, true, INFINITE);for (unsigned long long c0(0); c0 < threadNum - 1; ++c0)CloseHandle(threads[c0]);::free(threads);::free(paras);}}return b; }總結
以上是生活随笔為你收集整理的循环取矩阵的某行_1.2 震惊! 某大二本科生写的矩阵乘法吊打Mathematica-线性代数库BLAS-矩阵 (上)...的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: hibernate mysql casc
- 下一篇: oracle数据库恢复aul_RMAN备