[Machine Learning] The Principle of Gradient Descent
In principle, the minimum can be found analytically: take the derivative and solve for the point where it equals 0. Gradient descent reaches the same point numerically, one small step at a time.
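As a sanity check on the numerical runs below, the analytic answer is easy to compute here. A minimal sketch, assuming the objective f(x) = (x - 3)**2 + 2.5*x, which is reconstructed from the derivative used later:

```python
# d(x) = 2*(x - 3) + 2.5 = 2*x - 3.5; setting it to 0 gives the minimizer
analytic_min = 3.5 / 2
print(analytic_min)  # 1.75, which the descent loop below approaches (1.7496...)
```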
Plotting the objective first makes the single minimum visible:

```python
import numpy as np
import matplotlib.pyplot as plt

# f is not defined in the original snippet; this definition is inferred from
# the derivative d(x) = 2*(x - 3) + 2.5 used in the descent loop below
f = lambda x: (x - 3)**2 + 2.5*x

x = np.linspace(-2, 5, 100)
y = f(x)
plt.plot(x, y)
plt.show()
```

Finding the minimum with gradient descent
```python
# derivative function of f
d = lambda x: 2*(x - 3) + 2.5

# learning rate: needs tuning; it controls how far each update moves
learning_rate = 0.1

# random starting point
min_value = np.random.randint(-3, 5, size=1)[0]
print('-'*30, min_value)

# keep the previous value so we can measure how much each update changed
min_value_last = min_value + 0.1

# tolerance: once an update changes the value by less than 1e-4, stop
tol = 0.0001
count = 0
while True:
    if np.abs(min_value - min_value_last) < tol:
        break
    # gradient descent step: move against the derivative
    min_value_last = min_value
    min_value = min_value - learning_rate * d(min_value)
    print("++++++++++%d" % count, min_value)
    count += 1
print("*"*30, min_value)
```

Output (the random start here happened to be -2):

```
------------------------------ -2
++++++++++0 -1.25
++++++++++1 -0.6499999999999999
++++++++++2 -0.16999999999999993
++++++++++3 0.21400000000000008
++++++++++4 0.5212000000000001
++++++++++5 0.7669600000000001
++++++++++6 0.9635680000000001
++++++++++7 1.1208544
++++++++++8 1.24668352
++++++++++9 1.347346816
++++++++++10 1.4278774528
++++++++++11 1.49230196224
++++++++++12 1.543841569792
++++++++++13 1.5850732558336
++++++++++14 1.6180586046668801
++++++++++15 1.644446883733504
++++++++++16 1.6655575069868032
++++++++++17 1.6824460055894426
++++++++++18 1.695956804471554
++++++++++19 1.7067654435772432
++++++++++20 1.7154123548617946
++++++++++21 1.7223298838894356
++++++++++22 1.7278639071115485
++++++++++23 1.7322911256892388
++++++++++24 1.735832900551391
++++++++++25 1.7386663204411128
++++++++++26 1.7409330563528902
++++++++++27 1.7427464450823122
++++++++++28 1.7441971560658498
++++++++++29 1.74535772485268
++++++++++30 1.7462861798821439
++++++++++31 1.7470289439057152
++++++++++32 1.7476231551245722
++++++++++33 1.7480985240996578
++++++++++34 1.7484788192797263
++++++++++35 1.748783055423781
++++++++++36 1.7490264443390249
++++++++++37 1.7492211554712198
++++++++++38 1.749376924376976
++++++++++39 1.7495015395015807
++++++++++40 1.7496012316012646
****************************** 1.7496012316012646
```
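Two things are worth noting about this run. First, with this objective the update rule is linear: min_value ← min_value − 0.1·(2·min_value − 3.5) = 0.8·min_value + 0.35, a contraction whose fixed point is 0.35 / 0.2 = 1.75, so each step closes 20% of the remaining distance to 1.75 and the tolerance check stops the loop just short of it, at 1.7496. Second, the loop generalizes easily; here is a minimal sketch of a reusable version (the name `gradient_descent` and its signature are illustrative, not from the original post):

```python
def gradient_descent(d, x0, learning_rate=0.1, tol=1e-4, max_iter=3000):
    """Minimize a 1-D function given its derivative d, starting from x0."""
    x_last = x0 + 10 * tol          # force the first iteration to run
    x = x0
    for _ in range(max_iter):
        if abs(x - x_last) < tol:   # update became negligible: converged
            break
        x_last = x
        x = x - learning_rate * d(x)
    return x

print(gradient_descent(lambda x: 2*(x - 3) + 2.5, x0=-2))  # ~1.7496
```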
Gradient ascent to find the maximum

The same procedure finds a maximum if the update moves with the gradient instead of against it: max_value ← max_value + learning_rate·d2(max_value). Near a maximum (or minimum) the derivative approaches 0, so the update learning_rate*d2(max_value) keeps shrinking and the stopping condition np.abs(max_value - max_value_last) < precision is eventually satisfied.
```python
# derivative of the flipped objective g(x) = -(x - 3)**2 + 2.5*x
d2 = lambda x: -2*(x - 3) + 2.5

# learning rate: needs tuning; it controls how far each update moves
learning_rate = 0.1

# random starting point
max_value = np.random.randint(-3, 5, size=1)[0]
# the original printed min_value here by mistake, which is why the output
# below opens with the previous result 1.7496... rather than the new start
print('-'*30, max_value)

# keep the previous value so we can measure how much each update changed
max_value_last = max_value + 0.1
result = []

# precision: once an update changes the value by less than 1e-4, stop
precision = 0.0001
count = 0
while True:
    if count > 3000:
        # safety cap: with a badly tuned learning rate (e.g. rate = 1 the
        # iterates oscillate without settling; rate = 10 the updates explode)
        # the precision check would never be reached
        break
    if np.abs(max_value - max_value_last) < precision:
        break
    # gradient ascent step: add the derivative term instead of subtracting it
    max_value_last = max_value
    max_value = max_value + learning_rate * d2(max_value)
    result.append(max_value)
    print("++++++++++%d" % count, max_value)
    count += 1
print("*"*30, max_value)
```

Output (the random start here happened to be -1):

```
------------------------------ 1.7496012316012646
++++++++++0 0.050000000000000044
++++++++++1 0.8900000000000001
++++++++++2 1.5620000000000003
++++++++++3 2.0996
++++++++++4 2.52968
++++++++++5 2.873744
++++++++++6 3.1489952
++++++++++7 3.36919616
++++++++++8 3.545356928
++++++++++9 3.6862855424
++++++++++10 3.79902843392
++++++++++11 3.889222747136
++++++++++12 3.9613781977088
++++++++++13 4.01910255816704
++++++++++14 4.065282046533632
++++++++++15 4.102225637226906
++++++++++16 4.131780509781525
++++++++++17 4.15542440782522
++++++++++18 4.174339526260176
++++++++++19 4.18947162100814
++++++++++20 4.201577296806512
++++++++++21 4.2112618374452095
++++++++++22 4.219009469956168
++++++++++23 4.225207575964935
++++++++++24 4.230166060771948
++++++++++25 4.234132848617558
++++++++++26 4.237306278894047
++++++++++27 4.239845023115238
++++++++++28 4.24187601849219
++++++++++29 4.2435008147937525
++++++++++30 4.244800651835002
++++++++++31 4.2458405214680015
++++++++++32 4.246672417174401
++++++++++33 4.247337933739521
++++++++++34 4.247870346991617
++++++++++35 4.248296277593293
++++++++++36 4.248637022074634
++++++++++37 4.248909617659708
++++++++++38 4.2491276941277665
++++++++++39 4.249302155302213
++++++++++40 4.249441724241771
++++++++++41 4.249553379393417
++++++++++42 4.249642703514733
****************************** 4.249642703514733
```
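The run above appends every iterate to `result` but never uses it. A minimal sketch of how that list might be used to visualize convergence (the styling choices are illustrative):

```python
plt.plot(result, marker='.')
plt.axhline(3 + 2.5/2, color='gray', linestyle='--')  # analytic maximum at x = 4.25
plt.xlabel('iteration')
plt.ylabel('max_value')
plt.show()
```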
總結(jié)