python实现随机梯度下降(SGD)
作者:芳草碧连天lc 时间:2021-04-15 19:41:20
使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):
def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
if test_data:
n_test = len(test_data)#有多少个测试集
n = len(training_data)
for j in xrange(epochs):
random.shuffle(training_data)
mini_batches = [
training_data[k:k+mini_batch_size]
for k in xrange(0,n,mini_batch_size)]
for mini_batch in mini_batches:
self.update_mini_batch(mini_batch, eta)
if test_data:
print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
else:
print "Epoch {0} complete".format(j)
其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。
def update_mini_batch(self, mini_batch,eta):
nabla_b = [np.zeros(b.shape) for b in self.biases]
nabla_w = [np.zeros(w.shape) for w in self.weights]
for x,y in mini_batch:
delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
#最终更新权重为
self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]
这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。
来源:http://blog.csdn.net/leichaoaizhaojie/article/details/56840328
标签:python,梯度下降,SGD
0
投稿
猜你喜欢
python中判断文件结束符的具体方法
2021-09-28 13:31:53
JavaScript获取URL汇总
2024-02-24 10:40:07
iframe的防插与强插
2009-03-03 12:33:00
CentOS 6.4安装配置LAMP服务器(Apache+PHP5+MySQL)
2023-11-21 21:42:33
Mootools 1.2教程(12)——用Drag.Move实现拖拽和拖放
2008-12-05 12:29:00
python使用sessions模拟登录淘宝的方式
2023-01-09 12:05:25
教你如何使Python爬取酷我在线音乐
2021-02-18 14:13:01
Python 数据类型--集合set
2021-11-23 21:17:54
Python错误的处理方法
2021-08-01 05:38:15
解决Jupyter无法导入已安装的 module问题
2022-05-13 07:14:18
Vue基础学习之项目整合及优化
2024-05-21 10:28:49
Python 限制线程的最大数量的方法(Semaphore)
2022-03-02 06:24:09
PyCharm第一次安装及使用教程
2022-06-21 23:18:00
JavaScript几种弹窗事件的使用
2023-08-24 15:59:08
javascript 缓冲效果实现代码 推荐
2024-04-29 13:36:08
python中range()与xrange()用法分析
2021-03-23 00:31:30
Django migrate报错的解决方案
2021-05-16 12:48:30
Python K最近邻从原理到实现的方法
2022-10-13 09:41:45
vue如何通过params和query传值(刷新不丢失)
2024-05-09 15:17:23
VUE写一个简单的表格实例
2023-07-02 16:56:30