【Knowledge distillation: A good teacher is patient and consistent】
在計算機視覺方面,實現最先進性能的大型模型與實際應用中簡單的模型之間的差距越來越大。在本文中,將解決這個問題,并顯著地彌補這2種模型之間的差距。
在實證研究中,作者的目標不是一定要提出一種新的方法,而是努力確定一種穩健和有效的配置方案,使最先進的大模型在實踐中能夠得到應用。本文證明了在正確使用的情況下,知識蒸餾可以在不影響大模型性能的情況下減小它們的規模。作者還發現有某些隱式的設計選擇可能會極大地影響蒸餾的有效性。
作者的主要貢獻是明確地識別了這些設計選擇。作者通過一項全面的實證研究來支持本文的發現,在廣泛的視覺數據集上展示了很不錯的結果,特別是,為ImageNet獲得了最先進的ResNet-50模型,達到了82.8%的Top-1精度。
一、簡介.
大型視覺模型目前主導著計算機視覺的許多領域。最新的圖像分類、目標檢測或語義分割模型都將模型的大小推到現代硬件允許的極限。盡管它們的性能令人印象深刻,但由于計算成本高,這些模型很少在實踐中使用。
相反,實踐者通常使用更小的模型,如ResNet-50或MobileNet等,這些模型運行起來代價更低。根據Tensorflow Hub的5個BiT的下載次數,最小的ResNet-50的下載次數明顯多于較大的模型。因此,許多最近在視覺方面的改進并沒有轉化為現實世界的應用程序。
為了解決這個問題,本文將專注于以下任務:給定一個特定的應用程序和一個在它上性能很好的大模型,目標是在不影響性能的情況下將模型壓縮到一個更小、更高效的模型體系結構。針對這個任務有2種廣泛使用的范例:模型剪枝和知識蒸餾。
模型剪枝通過剝離大模型的各個部分來減少大模型的大小。這個過程在實踐中可能會有限制性:首先,它不允許更改模型族,比如從ResNet到MobileNet。其次,可能存在依賴于架構的挑戰;例如,如果大模型使用GN,修剪通道可能導致需要動態地重新分配通道組。
相反,作者專注于沒有這些缺點的知識蒸餾方法。知識蒸餾背后的理念是“提煉”一個教師模型,在本文例子中,一個龐大而繁瑣的模型或模型集合,制成一個小而高效的學生模型。這是通過強迫學生模型的預測與教師模型的預測相匹配,從而自然地允許模型家族的變化作為壓縮的一部分。
密切遵循Hinton的原始蒸餾配置,發現如果操作正確,它驚人地有效;如圖1所示作者將蒸餾解釋為匹配教師和學生實現的函數的任務。通過這種解釋發現對模型壓縮的知識蒸餾的2個關鍵原則。
- 首先,教師和學生模型應該處理完全相同的輸入圖像,或者更具體地說,相同的裁剪和數據增強;
- 其次,希望函數在大量的支撐點上匹配,以便更好地推廣。
使用Mixup的變體,可以在原始圖像流形外生成支撐點??紤]到這一點,通過實驗證明,一致的圖像視圖、合適的數據增強和非常長的訓練計劃是通過知識蒸餾使模型壓縮在實踐中工作良好的關鍵。
盡管發現明顯很簡單,但有很多種原因可能會阻止研究人員(和從業者)做出建議的設計選擇。
- 首先,很容易預先計算教師對離線圖像的激活量,以節省計算量,特別是對于非常大的教師模型;
- 其次,知識蒸餾也通常用于不同的上下文(除了模型壓縮),其中作者推薦不同甚至相反的設計選擇;
- 最后,知識蒸餾需要比較多的Epoch來達到最佳性能,比通常用于監督訓練的Epoch要多得多。更糟糕的是,在常規時間的訓練中看起來不理想的選擇往往是最好的,反之亦然。
在本文的實證研究中,主要集中于壓縮大型BiT-ResNet-152x2,它在ImageNet-21k數據集上預訓練,并對感興趣的相關數據集進行微調。在不影響精度的情況下,將其蒸餾為標準的ResNet-50架構(用GN代替BN)。還在ImageNet數據集上取得了非常強的結果:總共有9600個蒸餾周期,在ImageNet上得到了新的ResNet-50SOTA結果,達到了驚人的82.8%。這比原始的ResNet-50模型高出4.4%,比文獻中最好的ResNet-50模型高出2.2%。
最后,作者還證明了本文的蒸餾方案在同時壓縮和更改模型時也可以工作,例如BiT-ResNet架構到MobileNet架構。
二、實驗配置
2.1 Datasets, metrics and evaluation protocol
在5個流行的圖像分類數據集上進行了實驗:flowers102,pets,food101,sun397和ILSVRC-2012(“ImageNet”)。這些數據集跨越了不同的圖像分類場景;特別是,它們的類的數量不同,從37到1000個類,訓練圖像的總數從1020到1281167個不等。
2.2 Teacher and student models
在本文中,選擇使用來自BiT的預訓練教師模型,該模型提供了大量在ILSVRC-2012和ImageNet-21k數據集上預訓練的ResNet模型,具有最先進的準確性。BiT-ResNets與標準ResNets唯一顯著的區別是使用了GN層和權重標準化。
特別地專注于BiT-M-R152x2架構:在ImageNet-21k上預訓練的BiT-ResNet-152x2(152層,“x2”表示寬度倍數)。該模型在各種視覺基準上都顯示出了優異的性能,而且它仍然可以使用它進行廣泛的消融研究。盡管如此,它的部署成本還是很昂貴的(它需要比標準ResNet-50多10倍的計算量),因此該模型的有效壓縮具有實際的重要性。對于學生模型的架構,使用了一個BiT-ResNet-50變體,為了簡潔起見,它被稱為ResNet-50。
2.3 Distillation loss
這里使用教師模型的和學生模型的之間的KL散度作為一個蒸餾損失來預測類概率向量。對于原始數據集的硬標簽,不使用任何額外的損失:
C是類別。這里還引入了一個溫度參數T,用于在損失計算之前調整預測的softmax-probability分布的熵:
2.4 Training setup
為了優化,使用帶有默認參數的Adam優化器訓練模型。還使用了不帶有Warm up的余弦學習率機制。
作者同時還為所有的實驗使用了解耦的權重衰減機制。為了穩定訓練,在梯度的全局l2范數上以1.0的閾值進行梯度裁剪。最后,除在ImageNet上訓練的模型使用batch size為4096進行訓練外,對其他所有實驗都使用batch size為512。
本文的方案的另一個重要組成部分是Mixup數據增強策略。特別在“函數匹配”策略中中引入了一個Mixup變量,其中使用從[0,1]均勻抽樣的較強的Mixup系數,這可以看作是最初提出的β分布抽樣的一個極端情況。
作者還使用了““inception-style”的裁剪,然后將圖像的大小調整為固定的正方形大小。此外,為了能夠廣泛的分析在計算上的可行(訓練了數十萬個模型),除了ImageNet實驗,使用標準輸入224×224分辨率,其他數據集均使用相對較低的輸入分辨率,并將輸入圖像的大小調整為128×128大小。
三、模型蒸餾
3.1 “consistent and patient teacher”假說
在本節中,對介紹中提出的假設進行實驗驗證,如圖1所示,當作為函數匹配時,蒸餾效果最好,即當學生和教師模型輸入圖像是一致視圖時,通過mixup合成“filled”,當學生模型接受長時間的訓練時(即“教師”很有耐心)。
為了確保假說的穩健性,作者對4個中小型數據集進行了非常徹底的分析,即Flowers102,Pets,Food101,Sun397進行了訓練。
為了消除任何混雜因素,作者對每個精餾設定使用學習速率{0.0003,0.001,0.003,0.01}與權重衰減{1× 1 0 ? 5 10^{-5} 10?5,3× 1 0 ? 5 10^{-5} 10?5,1× 1 0 ? 4 10^{-4} 10?4,3× 1 0 ? 4 10^{-4} 10?4,1× 1 0 ? 3 10^{-3} 10?3}以及蒸餾溫度{1,2,5,10}的所有組合。
3.1.1.Importance of “consistent” teaching
首先,證明了一致性標準,即學生和教師看到相同的視圖,是執行蒸餾的唯一方法,它可以在所有數據集上一致地達到學生模型的最佳表現。在本研究中,定義了多個蒸餾配置,它們對應于圖1中所示的所有4個選項的實例化:
1. Fixed teacher
作者探索了幾個選項,其中教師模型的預測是恒定的,為一個給定的圖像。
最簡單(也是最差的)的方法是fix/rs,即學生和老師的圖像大小都被調整到224x224pixel。
fix/cc遵循一種更常見的方法,即教師使用固定的central crop,而學生使用random crop。
fix/ic_ens是一種重數據增強方法,教師模型的預測是1024種inception crops的平均值,我們驗證了以提高教師的表現。該學生模型使用random crop。
2. Independent noise
用2種方式實例化了這種常見的策略:
ind/rc分別為教師和學生計算2種獨立的random crop;
ind/ic則使用heavy inception crop。
3. Consistent teaching
在這種方法中,只對圖像進行隨機裁剪一次,要么是mild random cropping(same/rc),要么是heavy inception crop(same/ic),并使用相同的crop向學生和教師模型提供輸入。
4. Function matching
這種方法擴展了consistent teaching,通過mixup擴展圖像的輸入,并再次為學生和教師模型提供一致的輸入。為了簡潔起見,將這種方法稱為“FunMatch”。
3.1.2 Importance of “patient” teaching
人們可以將蒸餾解釋為監督學習的一種變體,其中標簽是由一個強大的教師模型提供的。當教師模型的預測計算為單一圖像視圖時,這一點尤其正確。這種方法繼承了標準監督學習的所有問題,例如,嚴重的數據增強可能會扭曲實際的圖像標簽,而輕微的增強可能又會導致過擬合。
然而,如果將蒸餾解釋為函數匹配,并且最重要的是,確保為學生和老師模型提供一致的輸入,情況就會發生變化。在這種情況下,可以進行比較強的圖像增強:即使圖像視圖過于扭曲,仍然會在匹配該輸入上的相關函數方面取得進展。因此,可以通過增強來增加機會,通過做比較強的圖像增強來避免過擬合,如果正確,可以優化很長一段時間,直到學生模型的函數接近教師模型的函數。
在圖4中證實了作者的假設,對于每個數據集,顯示了在訓練最佳函數匹配學生模型時不同數量的訓練Epoch的測試精度的變化。教師模型為一條紅線,經過比在標準監督訓練中使用的更多的Epoch后,最終總是能夠達到。至關重要的是,即使優化了一百萬個Epoch,也沒有過擬合的跡象。
作者還訓練和調整了另外2個Baseline以供參考:使用數據集原始硬標簽從零開始訓練ResNet-50,以及傳輸在ImageNet-21k上預訓練的ResNet-50。對于這2個Baseline,側重于調整學習率和權重衰減。使用原始標簽從零開始訓練的模型大大優于學生模型。
值得注意的是,相對較短的100個Epoch的訓練結果比遷移Baseline差得多??偟膩碚f,ResNet-50的學生模型持續地匹配ResNet-152x2教師模型。
CIFAR-10 Example
以Cifar-10數據集為例,驗證蒸餾得到的resnet-50模型的準確率
weights_cifar10 = get_weights('BiT-M-R50x1-CIFAR10') model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10) # NOTE: No new head. model.load_from(weights_cifar10) model.to(device); def eval_cifar10(model, bs=100, progressbar=True):loader_test = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=2)model.eval()if progressbar is True:progressbar = display(progress(0, len(loader_test)), display_id=True)preds = []with torch.no_grad():for i, (x, t) in enumerate(loader_test):x, t = x.to(device), t.numpy()logits = model(x)_, y = torch.max(logits.data, 1)preds.extend(y.cpu().numpy() == t)progressbar.update(progress(i+1, len(loader_test)))return np.mean(preds) print("Expected: 97.61%") print(f"Accuracy: {eval_cifar10(model):.2%}")評估預訓練模型,輸出如下:
Expected: 97.61%Accuracy: 97.62%找到索引以創建5個鏡頭的CIFAR10變體
preprocess_tiny = tv.transforms.Compose([tv.transforms.CenterCrop((2, 2)), tv.transforms.ToTensor()]) trainset_tiny = tv.datasets.CIFAR10(root='./data', train=True, download=True, transform=preprocess_tiny) loader = torch.utils.data.DataLoader(trainset_tiny, batch_size=50000, shuffle=False, num_workers=2) images, labels = iter(loader).next() indices = {cls: np.random.choice(np.where(labels.numpy() == cls)[0], 5, replace=False) for cls in range(10)} print(indices) fig = plt.figure(figsize=(10, 4)) ig = ImageGrid(fig, 111, (5, 10)) for c, cls in enumerate(indices):for r, i in enumerate(indices[cls]):img, _ = trainset[i]ax = ig.axes_column[c][r]ax.imshow((img.numpy().transpose([1, 2, 0]) * 127.5 + 127.5).astype(np.uint8))ax.set_axis_off() fig.suptitle('The whole 5-shot CIFAR10 dataset'); train_5shot = torch.utils.data.Subset(trainset, indices=[i for v in indices.values() for i in v]) len(train_5shot)輸出如下
50微調BiT-M(resnet-50)在這個5-shot CIFAR10變體上
model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10, zero_head=True) model.load_from(weights) model.to(device); sampler = torch.utils.data.RandomSampler(train_5shot, replacement=True, num_samples=256) loader_train = torch.utils.data.DataLoader(train_5shot, batch_size=256, num_workers=2, sampler=sampler) crit = nn.CrossEntropyLoss() opti = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9) model.train(); S = 500 def schedule(s):step_lr = stairs(s, 3e-3, 200, 3e-4, 300, 3e-5, 400, 3e-6, S, None)return rampup(s, 100, step_lr)pb_train = display(progress(0, S), display_id=True) pb_test = display(progress(0, 100), display_id=True) losses = [[]] accus_train = [[]] accus_test = []steps_per_iter = 512 // loader_train.batch_sizewhile len(losses) < S:for x, t in loader_train:x, t = x.to(device), t.to(device)logits = model(x)loss = crit(logits, t) / steps_per_iterloss.backward()losses[-1].append(loss.item())with torch.no_grad():accus_train[-1].extend(torch.max(logits, dim=1)[1].cpu().numpy() == t.cpu().numpy())if len(losses[-1]) == steps_per_iter:losses[-1] = sum(losses[-1])losses.append([])accus_train[-1] = np.mean(accus_train[-1])accus_train.append([])# Update learning-rate according to schedule, and stop if necessarylr = schedule(len(losses) - 1)for param_group in opti.param_groups:param_group['lr'] = lropti.step()opti.zero_grad()pb_train.update(progress(len(losses) - 1, S))print(f'\r[Step {len(losses) - 1}] loss={losses[-2]:.2e} 'f'train accu={accus_train[-2]:.2%} 'f'test accu={accus_test[-1] if accus_test else 0:.2%} 'f'(lr={lr:g})', end='', flush=True)if len(losses) % 25 == 0:accus_test.append(eval_cifar10(model, progressbar=pb_test))model.train()得到的損失函數、訓練準確率和測試準確率輸出如下
[Step 499] loss=2.23e-05 train accu=100.00% test accu=86.41% (lr=3e-06) fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4)) ax1.plot(losses[:-1]) ax1.set_yscale('log') ax1.set_title('loss') ax2.plot(accus_train[:-1]) ax2.set_title('training accuracy') ax3.plot(np.arange(25, 501, 25), accus_test) ax3.set_title('test accuracy');得到的損失函數、訓練準確率和測試準確率圖像輸出如下
參考文章:
讓ResNet-50精度高達82.8%!ViT原作者的知識蒸餾新作 | CVPR 2022 Oral
總結
以上是生活随笔為你收集整理的【Knowledge distillation: A good teacher is patient and consistent】的全部內容,希望文章能夠幫你解決所遇到的問題。