深入理解Spark 2.1 Core (十二):TimSort 的原理与源码分析
使用Sort等對(duì)數(shù)據(jù)進(jìn)行排序,其中用到了TimSort
這篇博文我們就來(lái)深入理解下TimSort
理解timsort
看完視頻后也許你會(huì)發(fā)現(xiàn)TimSort和MergeSort非常像。沒(méi)錯(cuò),這里推薦先閱讀關(guān)于理解timsort的博文,你就會(huì)發(fā)現(xiàn)它其實(shí)只是對(duì)歸并排序進(jìn)行了一系列的改進(jìn)。其中有一些是很聰明的,而也有一些是相當(dāng)簡(jiǎn)單直接的。這些大大小小的改進(jìn)聚集起來(lái)使得算法的效率變得十分的吸引人。
Spark TimSort 源碼分析
其實(shí)OpenJDK在Java SE 7的Arrays關(guān)于Object元素?cái)?shù)組的sort也使用了TimSort,而Spark的org.apache.spark.util.collection包中的用Java編寫(xiě)的TimSort也和java SE 7中的TimSort沒(méi)有太大區(qū)別。
public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {assert c != null;// 未排序的數(shù)組長(zhǎng)度int nRemaining = hi - lo;// 若數(shù)組大小為 0 或者 1// 那么就以及排序了if (nRemaining < 2)return; // 若是小數(shù)組// 則不使用歸并排序if (nRemaining < MIN_MERGE) {// 得到遞增序列的長(zhǎng)度int initRunLen = countRunAndMakeAscending(a, lo, hi, c);// 二分插入排序binarySort(a, lo, hi, lo + initRunLen, c);return;}// 棧SortState sortState = new SortState(a, c, hi - lo);// 得到最小run長(zhǎng)度int minRun = minRunLength(nRemaining);do {// 得到遞增序列的長(zhǎng)度int runLen = countRunAndMakeAscending(a, lo, hi, c);// 若run太小,// 使用二分插入排序if (runLen < minRun) {int force = nRemaining <= minRun ? nRemaining : minRun;binarySort(a, lo, lo + force, lo + runLen, c);runLen = force;}// 入棧sortState.pushRun(lo, runLen);// 可能進(jìn)行歸并sortState.mergeCollapse();// 查找下一run的預(yù)操作lo += runLen;nRemaining -= runLen;} while (nRemaining != 0);// 歸并所有剩余的run,完成排序assert lo == hi;sortState.mergeForceCollapse();assert sortState.stackSize == 1;}- 36
我們接下來(lái)逐個(gè)深入的講解:
countRunAndMakeAscending
private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {assert lo < hi;int runHi = lo + 1;if (runHi == hi)return 1;K key0 = s.newKey();K key1 = s.newKey();// 找到run的尾部if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // 若是遞減的,找到尾部反轉(zhuǎn)runwhile (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)runHi++;reverseRange(a, lo, runHi);} else { while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)runHi++;}// 返回run的長(zhǎng)度return runHi - lo;}- 1
binarySort
private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {assert lo <= start && start <= hi;if (start == lo)start++;K key0 = s.newKey();K key1 = s.newKey();Buffer pivotStore = s.allocate(1);// 將位置[start,hi)上的元素二分插入排序到已經(jīng)有序的[lo,start)序列中for ( ; start < hi; start++) {s.copyElement(a, start, pivotStore, 0);K pivot = s.getKey(pivotStore, 0, key0);int left = lo;int right = start;assert left <= right;while (left < right) {int mid = (left + right) >>> 1;if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)right = mid;elseleft = mid + 1;}assert left == right;int n = start - left; // 對(duì)插入做簡(jiǎn)單的優(yōu)化switch (n) {case 2: s.copyElement(a, left + 1, a, left + 2);case 1: s.copyElement(a, left, a, left + 1);break;default: s.copyRange(a, left, a, left + 1, n);}s.copyElement(pivotStore, 0, a, left);}}- 1
- 5
minRunLength
private int minRunLength(int n) {assert n >= 0;int r = 0; // 這里 MIN_MERGE 為 2 的某次方// if n < MIN_MERGE ,// then 直接返回 n// else if n >= MIN_MERGE 且 n(>1) 為 2 的某次方,// then n 的二進(jìn)制低位第1位 為 0,r |= (n & 1) 一直為 0 ,即返回的是 MIN_MERGE / 2// else r 為之后一次循環(huán)的n的二進(jìn)制低位第1位值 k ,返回的值 MIN_MERGE/2< k < MIN_MERGE while (n >= MIN_MERGE) {r |= (n & 1);n >>= 1;}return n + r;}SortState.pushRun
入棧
private void pushRun(int runBase, int runLen) {this.runBase[stackSize] = runBase;this.runLen[stackSize] = runLen;stackSize++;}- 1
SortState.mergeCollapse
這部分代碼OpenJDK中存在著bug,我們先來(lái)看一下Java SE 7是如何實(shí)現(xiàn)的:
private void mergeCollapse() {while (stackSize > 1) {int n = stackSize - 2;if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {if (runLen[n - 1] < runLen[n + 1])n--;mergeAt(n);} else if (runLen[n] <= runLen[n + 1]) {mergeAt(n);} else {break; }} }- 1
我們來(lái)舉個(gè)例子:
當(dāng)棧中的片段長(zhǎng)度為:
120, 80, 25, 20
我們插入長(zhǎng)度的30的片段,由于25 < 20 + 30 并且 25 < 30,所以得到:
120, 80, 45, 30
現(xiàn)在,由于80 > 45 + 30 并且 45 > 30,于是合并結(jié)束。但這并不完全符合根據(jù)不變式的重存儲(chǔ),因?yàn)?20 < 80 + 45!
更多細(xì)節(jié)可以參閱相關(guān)博文,Spark也對(duì)此bug進(jìn)行了修復(fù),修復(fù)后的代碼如下:
private void mergeCollapse() {while (stackSize > 1) {int n = stackSize - 2;if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])|| (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {if (runLen[n - 1] < runLen[n + 1])n--;} else if (runLen[n] > runLen[n + 1]) {break; }mergeAt(n);}}- 1
- 4
SortState. mergeAt
private void mergeAt(int i) {assert stackSize >= 2;assert i >= 0;assert i == stackSize - 2 || i == stackSize - 3;int base1 = runBase[i];int len1 = runLen[i];int base2 = runBase[i + 1];int len2 = runLen[i + 1];assert len1 > 0 && len2 > 0;assert base1 + len1 == base2;// 若 i 是從棧頂數(shù)第3個(gè)位置// 則 將棧頂元素 賦值到 從棧頂數(shù)第2個(gè)位置runLen[i] = len1 + len2;if (i == stackSize - 3) {runBase[i + 1] = runBase[i + 2];runLen[i + 1] = runLen[i + 2];}stackSize--;K key0 = s.newKey();// 從 run1 中找到 run2的第1個(gè)元素的位置// 在這之前的run1的元素都可以被忽略int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);assert k >= 0;base1 += k;len1 -= k;if (len1 == 0)return;// 從 run2 中找到 run1的最后1個(gè)元素的位置// 在這之后的run2的元素都可以被忽略len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);assert len2 >= 0;if (len2 == 0)return;// 歸并run// 使用 min(len1, len2) 長(zhǎng)度的臨時(shí)數(shù)組if (len1 <= len2)mergeLo(base1, len1, base2, len2);elsemergeHi(base1, len1, base2, len2);}- 1
SortState. gallopRight
// key: run2的第1個(gè)值// a: 數(shù)組// base: run1的起始為位置// len: run1的長(zhǎng)度// hint: 從run1的hint位置開(kāi)始查找,這里我們傳入的值為 0 private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {assert len > 0 && hint >= 0 && hint < len;// 對(duì)二分查找的優(yōu)化:// 我們要從 run1中 截取出這樣一段數(shù)組// lastOfs = k+1// ofs = 2×k+1// run1[lastOfs] <= key <= run1[ofs]// 即在[lastOfs,ofs],做二分查找int ofs = 1;int lastOfs = 0;K key1 = s.newKey();// 若 run2的第1個(gè)值 < run1的第1個(gè)值// 其實(shí)我知道,可以直接返回 0// 但這里還是走了完整的算法流程 if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {// maxOfs = 1int maxOfs = hint + 1;// 不進(jìn)入循環(huán)while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {lastOfs = ofs;ofs = (ofs << 1) + 1;if (ofs <= 0) ofs = maxOfs;}// 不進(jìn)入if (ofs > maxOfs)ofs = maxOfs;// tmp = 0int tmp = lastOfs;// lastOfs = -1lastOfs = hint - ofs;// ofs = 0ofs = hint - tmp;} else { // 這種情況下,算法才會(huì)發(fā)揮真正的作用// maxOfs = lenint maxOfs = len - hint;while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {// 更新 lastOfs 和 ofslastOfs = ofs;ofs = (ofs << 1) + 1;// 防止溢出if (ofs <= 0) ofs = maxOfs;}if (ofs > maxOfs)ofs = maxOfs;// 這里都不會(huì)變lastOfs += hint;ofs += hint;}assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;// 進(jìn)行二分查找lastOfs++;while (lastOfs < ofs) {int m = lastOfs + ((ofs - lastOfs) >>> 1);if (c.compare(key, s.getKey(a, base + m, key1)) < 0)// key < a[b + m]ofs = m; else// a[b + m] <= keylastOfs = m + 1; }assert lastOfs == ofs; return ofs;}- 1
- 3
gallopLeft和上述代碼類似,就不再做講解。
SortState. mergeLo
private void mergeLo(int base1, int len1, int base2, int len2) {assert len1 > 0 && len2 > 0 && base1 + len1 == base2;// 使用 min(len1, len2) 長(zhǎng)度的臨時(shí)數(shù)組// 這里 len1 會(huì)較小Buffer a = this.a; Buffer tmp = ensureCapacity(len1);s.copyRange(a, base1, tmp, 0, len1);// tmp(run1) 上的指針int cursor1 = 0; // run2 上的指針int cursor2 = base2; // 合并結(jié)果 上的指針int dest = base1; // Move first element of second run and deal with degenerate cases// 優(yōu)化:// 注意: run2 的第一個(gè)元素比 run1的第一個(gè)元素小// run1 的最后一個(gè)元素 比 run2的最后一個(gè)元素大 // 把 run2 的第1個(gè) 元素復(fù)制到 最終結(jié)果的第1個(gè)位置s.copyElement(a, cursor2++, a, dest++);if (--len2 == 0) {// 若 len2 為 1// 直接 把 run1 拷貝到 最終結(jié)果中s.copyRange(tmp, cursor1, a, dest, len1);return;}if (len1 == 1) {// 若 len1 為 1// 把 run2 剩余的部分 拷貝到 最終結(jié)果中// 再把 run1 拷貝到 最終結(jié)果中s.copyRange(a, cursor2, a, dest, len2);s.copyElement(tmp, cursor1, a, dest + len2); return;}K key0 = s.newKey();K key1 = s.newKey();Comparator<? super K> c = this.c;// 對(duì)歸并排序的優(yōu)化: int minGallop = this.minGallop; outer:while (true) {// 主要思想為 使用 count1 count2 對(duì)插入進(jìn)行計(jì)數(shù)int count1 = 0; int count2 = 0; do {// 歸并assert len1 > 1 && len2 > 0;if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {s.copyElement(a, cursor2++, a, dest++);count2++;count1 = 0;if (--len2 == 0)break outer;} else {s.copyElement(tmp, cursor1++, a, dest++);count1++;count2 = 0;if (--len1 == 1)break outer;}// 若某個(gè)run連續(xù)拷貝的次數(shù)超過(guò)minGallop// 退出循環(huán)} while ((count1 | count2) < minGallop);// 我們認(rèn)為若某個(gè)run連續(xù)拷貝的次數(shù)超過(guò)minGallop,// 則可能還會(huì)出現(xiàn)更若某個(gè)run連續(xù)拷貝的次數(shù)超過(guò)minGallop// 所有需要重新進(jìn)行類似于mergeAt中的操作,// 截取出按“段”進(jìn)行歸并// 直到 count1 或者 count2 < MIN_GALLOPdo {assert len1 > 1 && len2 > 0;count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);if (count1 != 0) {s.copyRange(tmp, cursor1, a, dest, count1);dest += count1;cursor1 += count1;len1 -= count1;if (len1 <= 1) // len1 == 1 || len1 == 0break outer;}s.copyElement(a, cursor2++, a, dest++);if (--len2 == 0)break outer;count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);if (count2 != 0) {s.copyRange(a, cursor2, a, dest, count2);dest += count2;cursor2 += count2;len2 -= count2;if (len2 == 0)break outer;}s.copyElement(tmp, cursor1++, a, dest++);if (--len1 == 1)break outer;minGallop--;} while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);// 調(diào)整 minGallopif (minGallop < 0)minGallop = 0;minGallop += 2; } // 退出 outer 循環(huán)this.minGallop = minGallop < 1 ? 1 : minGallop; // 把尾部寫(xiě)入最終結(jié)果if (len1 == 1) {assert len2 > 0;s.copyRange(a, cursor2, a, dest, len2);s.copyElement(tmp, cursor1, a, dest + len2); } else if (len1 == 0) {throw new IllegalArgumentException("Comparison method violates its general contract!");} else {assert len2 == 0;assert len1 > 1;s.copyRange(tmp, cursor1, a, dest, len1);}}- 1
mergeHi與上述類似,就不再講解。
SortState.mergeForceCollapse
private void mergeForceCollapse() {// 將所有的run合并while (stackSize > 1) {int n = stackSize - 2;// 若第3個(gè)run 長(zhǎng)度 小于 棧頂?shù)膔un// 先歸并第2,3個(gè) runif (n > 0 && runLen[n - 1] < runLen[n + 1])n--;mergeAt(n);}}- 6
總結(jié)
Spark TimSort 中 對(duì)MergeSort大致有一下幾點(diǎn):
- 元素:不像MergeSort惰性的有原來(lái)的長(zhǎng)度為1,再由歸并自動(dòng)的生成新的歸并元素。TimSort是預(yù)先按連續(xù)遞增(或者將連續(xù)遞減的片段反轉(zhuǎn))的片段作為一個(gè)歸并元素,即run。
- 插入排序:若是長(zhǎng)度小的run,TimSort會(huì)改用二分的InsertSort以及對(duì)再它進(jìn)行一些小優(yōu)化,而不使用MergeSort
- 歸并的時(shí)機(jī):MergeSort的歸并時(shí)機(jī)是定死的,而TimSort中的時(shí)機(jī)是(n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])|| (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])。以及,若從棧頂開(kāi)始第3個(gè)run長(zhǎng)度 小于 棧頂?shù)膔un,先歸并第2,3個(gè)run。
- 截取出需要?dú)w并的片段:run1是頭部和run2的尾部都是會(huì)有可以不用進(jìn)行歸并的部分。 如TimSort從 run1中 截取出這樣一段片段:lastOfs = k+1,ofs = 2×k+1,run1[lastOfs] <= key <= run1[ofs]。再?gòu)脑撈紊线M(jìn)行二分查找,得到run1中需要?dú)w并的起始位置
- 歸并的優(yōu)化:對(duì)run長(zhǎng)度為1時(shí),進(jìn)行了小優(yōu)化。實(shí)現(xiàn)了按單個(gè)值和按片段歸并的協(xié)同。
總結(jié)
以上是生活随笔為你收集整理的深入理解Spark 2.1 Core (十二):TimSort 的原理与源码分析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 深入理解Spark 2.1 Core (
- 下一篇: spark sql定义RDD、DataF