基于python的BP神经网络及异或实现过程解析

作者:沙克的世界 时间:2021-10-29 00:02:01 

BP神经网络是最简单的神经网络模型了,三层能够模拟非线性函数效果。

基于python的BP神经网络及异或实现过程解析

难点:

  • 如何确定初始化参数?

  • 如何确定隐含层节点数量?

  • 迭代多少次?如何更快收敛?

  • 如何获得全局最优解?


'''
neural networks

created on 2019.9.24
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt

'''
neural network
'''
class NeuralNetwork:

def __init__(self, layer_nums, iter_num = 10000, batch_size = 1):
 self.__ILI = 0;
 self.__HLI = 1;
 self.__OLI = 2;
 self.__TLN = 3;

if len(layer_nums) != self.__TLN:
  raise Exception("layer_nums length must be 3");

self.__layer_nums = layer_nums; #array [layer0_num, layer1_num ...layerN_num]
 self.__iter_num = iter_num;
 self.__batch_size = batch_size;

def train(self, X, Y):
 X = numpy.array(X);
 Y = numpy.array(Y);

self.L = [];
 #initialize parameters
 self.__weight = [];
 self.__bias = [];
 self.__step_len = [];
 for layer_index in range(1, self.__TLN):
  self.__weight.append(numpy.random.rand(self.__layer_nums[layer_index - 1], self.__layer_nums[layer_index]) * 2 - 1.0);
  self.__bias.append(numpy.random.rand(self.__layer_nums[layer_index]) * 2 - 1.0);
  self.__step_len.append(0.3);

logging.info("bias:%s" % (self.__bias));
 logging.info("weight:%s" % (self.__weight));

for iter_index in range(self.__iter_num):
  sample_index = random.randint(0, len(X) - 1);
  logging.debug("-----round:%s, select sample %s-----" % (iter_index, sample_index));
  output = self.forward_pass(X[sample_index]);
  g = (-output[2] + Y[sample_index]) * self.activation_drive(output[2]);
  logging.debug("g:%s" % (g));
  for j in range(len(output[1])):
   self.__weight[1][j] += self.__step_len[1] * g * output[1][j];
  self.__bias[1] -= self.__step_len[1] * g;

e = [];
  for i in range(self.__layer_nums[self.__HLI]):
   e.append(numpy.dot(g, self.__weight[1][i]) * self.activation_drive(output[1][i]));
  e = numpy.array(e);
  logging.debug("e:%s" % (e));
  for j in range(len(output[0])):
   self.__weight[0][j] += self.__step_len[0] * e * output[0][j];
  self.__bias[0] -= self.__step_len[0] * e;

l = 0;
  for i in range(len(X)):
   predictions = self.forward_pass(X[i])[2];
   l += 0.5 * numpy.sum((predictions - Y[i]) ** 2);
  l /= len(X);
  self.L.append(l);

logging.debug("bias:%s" % (self.__bias));
  logging.debug("weight:%s" % (self.__weight));
  logging.debug("loss:%s" % (l));
 logging.info("bias:%s" % (self.__bias));
 logging.info("weight:%s" % (self.__weight));
 logging.info("L:%s" % (self.L));

def activation(self, z):
 return (1.0 / (1.0 + numpy.exp(-z)));

def activation_drive(self, y):
 return y * (1.0 - y);

def forward_pass(self, x):
 data = numpy.copy(x);
 result = [];
 result.append(data);
 for layer_index in range(self.__TLN - 1):
  data = self.activation(numpy.dot(data, self.__weight[layer_index]) - self.__bias[layer_index]);
  result.append(data);
 return numpy.array(result);

def predict(self, x):
 return self.forward_pass(x)[self.__OLI];

def main():
logging.basicConfig(level = logging.INFO,
  format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
  datefmt = '%a, %d %b %Y %H:%M:%S');

logging.info("trainning begin.");
nn = NeuralNetwork([2, 2, 1]);
X = numpy.array([[0, 0], [1, 0], [1, 1], [0, 1]]);
Y = numpy.array([0, 1, 0, 1]);
nn.train(X, Y);

logging.info("trainning end. predict begin.");
for x in X:
 print(x, nn.predict(x));

plt.plot(nn.L)
plt.show();

if __name__ == "__main__":
main();

具体收敛效果

基于python的BP神经网络及异或实现过程解析

来源:https://www.cnblogs.com/thsss/p/11588433.html

标签:python,bp,神经网络
0
投稿

猜你喜欢

  • 兼容IE,FF的弹出层登陆界面代码

    2008-01-04 12:13:00
  • pytorch 实现冻结部分参数训练另一部分

    2023-06-14 16:43:10
  • python中的txt文件转换为XML

    2021-12-05 10:45:48
  • python @classmethod 的使用场合详解

    2023-08-02 20:50:35
  • python各种语言间时间的转化实现代码

    2022-06-27 14:54:28
  • PHP开发实现快递查询功能详解

    2023-11-24 12:19:39
  • Python栈的实现方法示例【列表、单链表】

    2023-07-20 15:51:42
  • pyshp创建shp点文件的方法

    2023-06-30 03:15:29
  • 通过js获取div的background-image属性

    2023-08-23 06:07:23
  • 比较规范的验证Email地址是否正确的正则表达式

    2009-07-28 17:55:00
  • 9种使用Chrome Firefox 自带调试工具调试javascript技巧

    2023-07-19 01:03:48
  • 从"..."看中国的UI设计界的粗糙

    2007-11-21 19:28:00
  • asp如何用OdbcRegTool组件来创建一个数据源?

    2010-06-12 12:55:00
  • DHTML中重要的属性方法

    2008-06-21 17:13:00
  • PHP实现数组根据某个字段进行水平合并,横向合并案例分析

    2023-10-04 04:55:53
  • Python正则表达式的另类解答

    2023-08-02 06:58:04
  • javascript设置页面背景色及背景图片的方法

    2023-09-06 22:00:51
  • python 接收处理外带的参数方法

    2022-05-01 17:56:26
  • python多线程分块读取文件

    2023-10-29 18:48:51
  • 详解python代码模块化

    2023-08-08 23:08:23
  • asp之家 网络编程 m.aspxhome.com