【TensorFlow】实现简单的鸢尾花分类器
生活随笔
收集整理的這篇文章主要介紹了
【TensorFlow】实现简单的鸢尾花分类器
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
代碼實現及說明
# python 3.6 # TensorFlow實現簡單的鳶尾花分類器 import matplotlib.pyplot as plt import tensorflow as tf import numpy as np from sklearn import datasetssess = tf.Session()#導入數據 iris = datasets.load_iris() # 是否是山鳶尾 0/1 binary_target = np.array([1. if x == 0 else 0. forx in iris.target]) # 選擇兩個特征:花瓣長度和寬度 iris_2d = np.array([[x[2],x[3]] for x in iris.data])# 聲明批訓練大小、占位符和變量 # tf.float32降低float字節數 可以提高算法性能 batch_size = 20 x1_data = tf.placeholder(shape=[None,1],dtype=tf.float32) x2_data = tf.placeholder(shape=[None,1],dtype=tf.float32) y_target = tf.placeholder(shape=[None,1],dtype=tf.float32) # 聲明變量 A 和 b (0 = x1 - A*x2 + b) A = tf.Variable(tf.random_normal(shape=[1,1])) b = tf.Variable(tf.random_normal(shape=[1,1]))# 定義線性模型 # 線性模型的表達式為:x1=x2*A+b。 # 如果找到的數據點在直線以上,則將數據點代入x1-x2*A-b計算出的結果大于0; # 同理找到的數據點在直線以下,則將數據點代入x1-x2*A-b計算出的結果小于0。 # 將公式x1-x2*A-b傳入sigmoid函數,然后預測結果1或者0 # TensorFlow有內建的sigmoid損失函數,所以這里僅僅需要定義模型輸出 my_mult = tf.matmul(x2_data, A) my_add = tf.add(my_mult, b) my_output = tf.subtract(x1_data,my_add)# 增加分類損失函數 這里用兩類交叉熵損失函數 cross entropy xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output,labels=y_target)# 聲明優化器 my_opt = tf.train.GradientDescentOptimizer(0.05) train_step = my_opt.minimize(xentropy)# 初始化變量 init = tf.global_variables_initializer() sess.run(init)# 循環 for i in range(1000):rand_index = np.random.choice(len(iris_2d),size=batch_size)rand_x = iris_2d[rand_index]rand_x1 = np.array([[x[0]] for x in rand_x])rand_x2 = np.array([[x[1]] for x in rand_x])rand_y = np.array([[y] for y in binary_target[rand_index]])sess.run(train_step, feed_dict={x1_data:rand_x1,x2_data:rand_x2,y_target:rand_y})if (i+1)%200 == 0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))# 結果可視化 [[slope]] = sess.run(A) # 斜率 # 因為A的shape是(1,1)所以要寫成一行一列的形式 [[intercept]] = sess.run(b) # 截距# 創建擬合線 x = np.linspace(0, 3, num=50) # 0~3 50個均勻間隔的數字 ablineValues = [] for i in x:ablineValues.append(slope*i+intercept)# 繪圖 setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1] setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1] non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0] non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0] plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa') plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa') plt.plot(x, ablineValues, 'b-') plt.xlim([0.0, 2.7]) plt.ylim([0.0, 7.1]) plt.xlabel('Petal Length') plt.ylabel('Petal Width') plt.legend(loc='lower right') plt.show()繪圖結果
總結
這里利用花瓣長度和花瓣寬度的特征在山鳶尾與其他物種間擬合一條直線,然后通過該直線來分割兩類目標(山鳶尾和非山鳶尾),直線是迭代1000次得到的線性分割,通過直線分割兩個目標并不是最好的模型。
總結
以上是生活随笔為你收集整理的【TensorFlow】实现简单的鸢尾花分类器的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 企业战略咨询方法:学习SWOT分析
- 下一篇: Android官方开发文档Trainin