pytorch实现简易分类模型
生活随笔
收集整理的這篇文章主要介紹了
pytorch实现简易分类模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1 導入庫
import torch import matplotlib.pyplot as plt import torch.nn.functional as F2 數據處理
n_data=torch.ones(100,2)x_0=torch.normal(2*n_data,1) #x_0的shape是(100,2),值的均值為2*1,標準差為1 y_0=torch.zeros(100) #y_0的shape是(100,)x_1=torch.normal(-2*n_data,1) #x_1的shape是(100,2),值的均值為-2*1,標準差為1 y_1=torch.ones(100) #y_1的shape是(100,)x=torch.cat((x_0,x_1)) #默認沿著列的方向合并 y=torch.cat((y_0,y_1)).type(torch.LongTensor) #!!!必須得修改類型,否則之后報錯 #RuntimeError: expected scalar type Long but found Float ''' print(x.size(),y.size()) torch.Size([200, 2]) torch.Size([200]) '''#數據可視化 plt.scatter(x[:,0],x[:,1],c=y) plt.show()3 模型定義
class Net(torch.nn.Module):def __init__(self,n_input,n_hidden,n_output):super(Net,self).__init__()self.hidden=torch.nn.Linear(n_input,n_hidden)self.out=torch.nn.Linear(n_hidden,n_output)#前向傳播 def forward(self,x):x=F.relu(self.hidden(x))x=self.out(x)return xnet=Net(2,10,2) #輸入是每個點的橫縱坐標,輸出是屬于1和屬于0類的概率 net ''' Net((hidden): Linear(in_features=2, out_features=10, bias=True)(out): Linear(in_features=10, out_features=2, bias=True) ) '''4 模型訓練及可視化
#訓練及可視化 #優化函數 optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#損失函數 loss_func=torch.nn.CrossEntropyLoss()for epoch in range(100):prediction=net(x)loss=loss_func(prediction,y)optimizer.zero_grad()#清空上一步的參與更新的參數值loss.backward()#誤差反向傳播,計算參數更新值optimizer.step()#將參數更新值施加到net的parameters上if(epoch % 10==0):predict=torch.max(F.softmax(prediction,dim=0),axis=1)[1]#默認逐行進行softmax (對num[x][1,2,....n][y]進行操作)#每一行然后找到最大的一個,max返回的第一個元素是這一行最大的那個值,第二個元素是這一行最大的那個值對應的下標#計算準確率predict=predict.data.numpy()target=y.data.numpy() accuracy=sum(predict==target)/predict.size#分類結果可視化plt.scatter(x[:,0],x[:,1],c=predict)plt.text(1.5,-4,'accuracy:{acc:.2f}'.format(acc=accuracy))plt.show()第一步:
最后一步:
總結
以上是生活随笔為你收集整理的pytorch实现简易分类模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch笔记:实例解析NLLLos
- 下一篇: pytorch 学习笔记:nn.Sequ