神经网络(BP)算法Python实现及应用
作者:一清 时间:2021-11-12 20:00:14
本文实例为大家分享了Python实现神经网络算法及应用的具体代码,供大家参考,具体内容如下
首先用Python实现简单地神经网络算法:
import numpy as np
# 定义tanh函数
def tanh(x):
return np.tanh(x)
# tanh函数的导数
def tan_deriv(x):
return 1.0 - np.tanh(x) * np.tan(x)
# sigmoid函数
def logistic(x):
return 1 / (1 + np.exp(-x))
# sigmoid函数的导数
def logistic_derivative(x):
return logistic(x) * (1 - logistic(x))
class NeuralNetwork:
def __init__(self, layers, activation='tanh'):
"""
神经网络算法构造函数
:param layers: 神经元层数
:param activation: 使用的函数(默认tanh函数)
:return:none
"""
if activation == 'logistic':
self.activation = logistic
self.activation_deriv = logistic_derivative
elif activation == 'tanh':
self.activation = tanh
self.activation_deriv = tan_deriv
# 权重列表
self.weights = []
# 初始化权重(随机)
for i in range(1, len(layers) - 1):
self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
self.weights.append((2 * np.random.random((layers[i] + 1, layers[i + 1])) - 1) * 0.25)
def fit(self, X, y, learning_rate=0.2, epochs=10000):
"""
训练神经网络
:param X: 数据集(通常是二维)
:param y: 分类标记
:param learning_rate: 学习率(默认0.2)
:param epochs: 训练次数(最大循环次数,默认10000)
:return: none
"""
# 确保数据集是二维的
X = np.atleast_2d(X)
temp = np.ones([X.shape[0], X.shape[1] + 1])
temp[:, 0: -1] = X
X = temp
y = np.array(y)
for k in range(epochs):
# 随机抽取X的一行
i = np.random.randint(X.shape[0])
# 用随机抽取的这一组数据对神经网络更新
a = [X[i]]
# 正向更新
for l in range(len(self.weights)):
a.append(self.activation(np.dot(a[l], self.weights[l])))
error = y[i] - a[-1]
deltas = [error * self.activation_deriv(a[-1])]
# 反向更新
for l in range(len(a) - 2, 0, -1):
deltas.append(deltas[-1].dot(self.weights[l].T) * self.activation_deriv(a[l]))
deltas.reverse()
for i in range(len(self.weights)):
layer = np.atleast_2d(a[i])
delta = np.atleast_2d(deltas[i])
self.weights[i] += learning_rate * layer.T.dot(delta)
def predict(self, x):
x = np.array(x)
temp = np.ones(x.shape[0] + 1)
temp[0:-1] = x
a = temp
for l in range(0, len(self.weights)):
a = self.activation(np.dot(a, self.weights[l]))
return a
使用自己定义的神经网络算法实现一些简单的功能:
小案例:
X: Y
0 0 0
0 1 1
1 0 1
1 1 0
from NN.NeuralNetwork import NeuralNetwork
import numpy as np
nn = NeuralNetwork([2, 2, 1], 'tanh')
temp = [[0, 0], [0, 1], [1, 0], [1, 1]]
X = np.array(temp)
y = np.array([0, 1, 1, 0])
nn.fit(X, y)
for i in temp:
print(i, nn.predict(i))
发现结果基本机制,无限接近0或者无限接近1
第二个例子:识别图片中的数字
导入数据:
from sklearn.datasets import load_digits
import pylab as pl
digits = load_digits()
print(digits.data.shape)
pl.gray()
pl.matshow(digits.images[0])
pl.show()
观察下:大小:(1797, 64)
数字0
接下来的代码是识别它们:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
from NN.NeuralNetwork import NeuralNetwork
from sklearn.cross_validation import train_test_split
# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target
# 处理数据,使得数据处于0,1之间,满足神经网络算法的要求
X -= X.min()
X /= X.max()
# 层数:
# 输出层10个数字
# 输入层64因为图片是8*8的,64像素
# 隐藏层假设100
nn = NeuralNetwork([64, 100, 10], 'logistic')
# 分隔训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y)
# 转化成sklearn需要的二维数据类型
labels_train = LabelBinarizer().fit_transform(y_train)
labels_test = LabelBinarizer().fit_transform(y_test)
print("start fitting")
# 训练3000次
nn.fit(X_train, labels_train, epochs=3000)
predictions = []
for i in range(X_test.shape[0]):
o = nn.predict(X_test[i])
# np.argmax:第几个数对应最大概率值
predictions.append(np.argmax(o))
# 打印预测相关信息
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))
结果:
矩阵对角线代表预测正确的数量,发现正确率很多
这张表更直观地显示出预测正确率:
共450个案例,成功率94%
来源:http://www.cnblogs.com/xuyiqing/p/8797048.html
标签:Python,神经网络
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
非常不错的[JS]Cookie精通之路第1/2页
2023-09-04 04:04:34
python中有关时间日期格式转换问题
2023-03-17 07:43:12
![](https://img.aspxhome.com/file/2023/8/83658_0s.png)
python样条插值的实现代码
2022-05-11 16:04:05
![](https://img.aspxhome.com/file/2023/3/70833_0s.png)
实现SQL Server到DB2连接服务器很简单
2010-08-08 15:24:00
python函数形参用法实例分析
2023-09-08 21:07:09
python3的数据类型及数据类型转换实例详解
2022-06-30 11:24:45
![](https://img.aspxhome.com/file/2023/7/65447_0s.png)
响应浏览器地址栏#(hash/fragment)变化
2009-12-28 10:45:00
Pyinstaller+Pipenv打包Python文件的实现示例
2021-06-11 01:49:51
![](https://img.aspxhome.com/file/2023/4/82594_0s.png)
教你学会通过python的matplotlib库绘图
2022-03-04 13:08:52
![](https://img.aspxhome.com/file/2023/5/78605_0s.png)
Go语言框架快速集成限流中间件详解
2023-08-26 11:44:39
Python __slots__的使用方法
2023-11-19 16:15:10
Golang正整数指定规则排序算法问题分析
2023-07-12 09:12:03
如何使用python docx模块操作word文档
2022-04-23 12:16:02
![](https://img.aspxhome.com/file/2023/6/76086_0s.png)
DjangoRestFramework 使用 simpleJWT 登陆认证完整记录
2021-03-29 18:34:12
![](https://img.aspxhome.com/file/2023/1/76001_0s.png)
个人网站与动网整合非官方方法
2009-07-05 18:42:00
JS图形技术的终极体现
2008-04-30 19:43:00
![](https://img.aspxhome.com/file/UploadPic/20084/30/200843019490727s.jpg)
如何用Python搭建gRPC服务
2023-02-08 16:00:54
![](https://img.aspxhome.com/file/2023/3/92333_0s.png)
MYSQL和ORACLE的一些操作区别
2008-12-18 14:33:00
python opencv实现图像配准与比较
2023-03-01 15:30:24
![](https://img.aspxhome.com/file/2023/9/87899_0s.jpg)
python中关于时间和日期函数的常用计算总结(time和datatime)
2022-01-02 05:50:08