kmeans算法python实现(iris数据集)
首先從sklearn里面載入iris數據集
如下所示
Sepal_Length ?Sepal_Width ?Petal_Length ?Petal_Width ?Species
0 ? ? ? ? ? ? 5.1 ? ? ? ? ?3.5 ? ? ? ? ? 1.4 ? ? ? ? ?0.2 ? ? ? ?0
1 ? ? ? ? ? ? 4.9 ? ? ? ? ?3.0 ? ? ? ? ? 1.4 ? ? ? ? ?0.2 ? ? ? ?0
2 ? ? ? ? ? ? 4.7 ? ? ? ? ?3.2 ? ? ? ? ? 1.3 ? ? ? ? ?0.2 ? ? ? ?0
3 ? ? ? ? ? ? 4.6 ? ? ? ? ?3.1 ? ? ? ? ? 1.5 ? ? ? ? ?0.2 ? ? ? ?0
4 ? ? ? ? ? ? 5.0 ? ? ? ? ?3.6 ? ? ? ? ? 1.4 ? ? ? ? ?0.2 ? ? ? ?0
.. ? ? ? ? ? ?... ? ? ? ? ?... ? ? ? ? ? ... ? ? ? ? ?... ? ? ?...
145 ? ? ? ? ? 6.7 ? ? ? ? ?3.0 ? ? ? ? ? 5.2 ? ? ? ? ?2.3 ? ? ? ?2
146 ? ? ? ? ? 6.3 ? ? ? ? ?2.5 ? ? ? ? ? 5.0 ? ? ? ? ?1.9 ? ? ? ?2
147 ? ? ? ? ? 6.5 ? ? ? ? ?3.0 ? ? ? ? ? 5.2 ? ? ? ? ?2.0 ? ? ? ?2
148 ? ? ? ? ? 6.2 ? ? ? ? ?3.4 ? ? ? ? ? 5.4 ? ? ? ? ?2.3 ? ? ? ?2
149 ? ? ? ? ? 5.9 ? ? ? ? ?3.0 ? ? ? ? ? 5.1 ? ? ? ? ?1.8 ? ? ? ?2
[150 rows x 5 columns]
可以看到有4列為特征,最后一列為類別
這里為了畫圖方便僅使用了Sepal_Length? 和Petal_Width? 兩列
?可以看到特征和結果相關性挺高的
假如沒有標簽,看起來可以用kmeans解決,最后用kmeans看能不能得到類似的一個結果
# -*- coding: utf-8 -*- import glob from collections import defaultdict import matplotlib.pyplot as plt from sklearn.datasets import load_iris import pandas as pd import numpy as npdef plot_scatter(df, name, centers=None, title=None):'''畫圖'''plt.figure()plt.scatter(df['Sepal_Length'], df['Petal_Width'], c=df['Species'])plt.xlabel('Sepal_Length')plt.ylabel('Petal_Width')plt.legend()if centers:plt.scatter([i[0] for i in centers], [i[1] for i in centers], c='r')if title:plt.title(title)plt.savefig(name)def distance(point_a, point_b):'''歐氏距離計算'''return np.sqrt(sum((np.array(point_a) - np.array(point_b)) ** 2))def k_means(points, k):centers = [points[i] for i in range(k)]dict_ = []iter_num = 0while True:point_dict = defaultdict(list)for point in points:distances = [distance(center, point) for center in centers]class_ = np.argmin(distances)dict_.append({'Sepal_Length': point[0], 'Petal_Width': point[1], 'Species': class_}, )point_dict[class_].append(point)print({k: len(v) for k, v in point_dict.items()})new_centers = [np.array(points).mean(axis=0) for class_, points in point_dict.items()]dis = (np.array(new_centers) - centers)if abs(dis.mean()) <= 0.0002:breakelse:centers = new_centersplot_scatter(pd.DataFrame(dict_), f'kmeans_{iter_num}.png', centers, iter_num)iter_num += 1def png2jif():'''迭代生成的png轉為動圖'''file_names = glob.glob('*.png')from PIL import Imageim = Image.open(file_names[0])images = []for file_name in file_names[1:]:images.append(Image.open(file_name))im.save('gif.gif', save_all=True, append_images=images, loop=1, duration=500, comment=b"aaabb")def get_iris_df():iris = load_iris()iris_d = pd.DataFrame(iris['data'], columns=['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width'])iris_d['Species'] = iris.targetiris_d.dropna(inplace=True)return iris_dif __name__ == '__main__':iris_df = get_iris_df()plot_scatter(iris_df, 'raw.png')print(iris_df)points = iris_df[['Sepal_Length', 'Petal_Width']].valuesk = 3k_means(points, k)png2jif()?
?紅色為中心點,可以看到通過kmeans可以得到一個和原始結果相近的一個結果
總結
以上是生活随笔為你收集整理的kmeans算法python实现(iris数据集)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: stdout字符串过滤输出
- 下一篇: sql插入后返回id