CrossEntropyLabelSmooth 的代码
                                                            生活随笔
收集整理的這篇文章主要介紹了
                                CrossEntropyLabelSmooth 的代码
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.                        
                                本文的注釋可能有不正確的地方,望知
CrossEntropyLabelSmooth
import torch import torch.nn as nn from IPython import embedclass CrossEntropyLabelSmooth(nn.Module):"""用于指定帶標簽平滑的交叉熵公式 ,用于指定帶標簽平滑的交叉熵公式"""def __init__(self, num_classes, epsilon=0.1, use_gpu=True):"""__init__()方法參數有num_classes與epsilon第一個參數指定分類數量第二參數即標簽平滑公式中的epsilon 這里是對應的標簽平滑的過程的。"""super(CrossEntropyLabelSmooth, self).__init__()self.num_classes = num_classes # num_classes = 8self.epsilon = epsilon # epsilon = 0.1self.use_gpu = use_gpu # 是否是要來使用gpu的過程的。self.logsoftmax = nn.LogSoftmax(dim=1) # 把相應的Softmax在來通過log的形式的。def forward(self, inputs, targets):"""Args:inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)targets: ground truth labels with shape (num_classes) 是來對應的真實的標簽的。"""log_probs = self.logsoftmax(inputs) # torch.Size([4, 8])這里是得到批次為4,每一個屬于這8個類別中那一個概率是最大的# scatter_是來沿著1,列方向上這個維度來進行索引的。zeros[4,8]的列數要與scatter這邊的列數是相同的。# ongTensor中的index最大值應與zeros(4, 8)行數相一致targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data, 1) # 這里是來輸出相應的標簽信息的。print(targets)if self.use_gpu: targets = targets.cuda() #如果是存在gpu話,就是放在gpu上面來進行計算的過程的。# y = (1 - epsilon) * y + epsilon / K.targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes # 這里是來對應的標簽平滑化的過程的。print(targets)loss = (- targets * log_probs).mean(0).sum()return lossif __name__ == '__main__':batch_size = 4 # 首先是對應的批尺寸num_class = 8 # 對應的分類數是8loss_class = CrossEntropyLabelSmooth(num_classes=num_class, use_gpu=False)input = torch.randn([batch_size,num_class]) # 這里是來隨機生成一批數據的label = torch.randint(0, num_class-1, [batch_size]) # 用于生成指定范圍的整數,也就是對對應的標簽 tensor([4, 5, 3, 3])print(label)print(loss_class(input, label))總結
以上是生活随笔為你收集整理的CrossEntropyLabelSmooth 的代码的全部內容,希望文章能夠幫你解決所遇到的問題。
 
                            
                        - 上一篇: Open3D Delaunay三角剖分(
- 下一篇: 易语言调试工具 code by:↖星空·
