BP神经网络处理iris数据集(Pytorch实现)
一.數據集介紹
這次數據集使用的是iris數據集,也稱鳶尾花卉數據集,是一類多重變量分析的數據集。數據集包含150個數據樣本,分為3類,每類50個數據,每個數據包含4個屬性。可通過花萼長度,花萼寬度,花瓣長度,花瓣寬度4個屬性預測鳶尾花卉屬于(Setosa,Versicolour,Virginica)三個種類中的哪一類。
該數據集進行神經網絡時,輸入是 Sepal.Length(花萼長度), Sepal.Width(花萼寬度),Petal.Length(花瓣長度), Petal.Width(花瓣寬度),輸出為種類,Iris Setosa(山鳶尾)、Iris Versicolour(雜色鳶尾),以及Iris Virginica(維吉尼亞鳶尾)。
二.代碼實現
代碼部分總共為兩個版本,分別是CPU版本和GPU版本。
數據集是從sklearn中下載得到:
我們可以看一下該數據集的輸出,輸出可以自己看,這就不展示了。
print(data) print(iris_type)之后需要對數據進行處理,因為使用到pytorch,我們需要將數據轉為tensor格式:
input = torch.FloatTensor(dataset['data']) label = torch.LongTensor(dataset['target'])接下來可以定義神經網絡模型:
class BPNerualNetwork(torch.nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size1),nn.ReLU(),nn.Linear(hidden_size1, hidden_size2),nn.ReLU(),nn.Linear(hidden_size2, hidden_size3),nn.ReLU(),nn.Linear(hidden_size3, output_size),nn.LogSoftmax(dim=1))def forward(self, x):x = self.model(x)return x這里我設置了三層隱藏層,不過你可以自己增減隱藏層,只需要調用函數nn.Linear(),激活函數可以直接設置,pytorch里面可以直接調用,我這里使用的是nn.ReLU()函數。
后續只需要進行訓練就可以了(下面代碼是GPU版本):
三.效果
我使用了matplotlib將準確率和loss進行展示:
準確率達到了0.99,可以說效果不錯哦。
具體兩個版本代碼可以點這里下載。
總結
以上是生活随笔為你收集整理的BP神经网络处理iris数据集(Pytorch实现)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: SQL Server数据库可疑处理
- 下一篇: 渲染管线概述