tensorflow建立一个简单的神经网络的方法

作者:Mr丶Caleb 时间:2022-09-27 17:01:51 

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:


def add_layer(inputs, in_size, out_size, activation_function=None):
 # add one more layer and return the output of this layer
 Weights = tf.Variable(tf.random_normal([in_size, out_size]))
 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
 Wx_plus_b = tf.matmul(inputs, Weights) + biases
 if activation_function is None:
   outputs = Wx_plus_b
 else:
   outputs = activation_function(Wx_plus_b)
 return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据


# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。


numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source]
Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入


# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层


# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小


# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:


from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
 # add one more layer and return the output of this layer
 Weights = tf.Variable(tf.random_normal([in_size, out_size]))
 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
 Wx_plus_b = tf.matmul(inputs, Weights) + biases
 if activation_function is None:
   outputs = Wx_plus_b
 else:
   outputs = activation_function(Wx_plus_b)
 return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
          reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
 # training
 sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
 if i % 50 == 0:
   # to visualize the result and improvement
   try:
     ax.lines.remove(lines[0])
   except Exception:
     pass
   prediction_value = sess.run(prediction, feed_dict={xs: x_data})
   # plot the prediction
   lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
   plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

来源:http://blog.csdn.net/qq_30159351/article/details/52639291

标签:tensorflow,神经网络
0
投稿

猜你喜欢

  • pytest中配置文件pytest.ini使用

    2021-01-22 17:04:02
  • SQL Server 2000 SP4补丁打不上的解决办法

    2010-03-08 13:13:00
  • 解决Mac下首次安装pycharm无project interpreter的问题

    2023-02-11 04:32:15
  • 使用c#构造date数据类型

    2024-01-15 22:19:15
  • python字典的遍历3种方法详解

    2022-05-01 06:00:44
  • 详解python时间模块中的datetime模块

    2023-09-26 02:41:28
  • Python实现照片卡通化

    2021-03-29 18:45:40
  • Asp的上下午时间格式问题

    2009-04-13 16:06:00
  • python语言中pandas字符串分割str.split()函数

    2022-01-30 16:55:56
  • golang中的并发和并行

    2024-04-26 17:15:11
  • OpenCV实现图片亮度增强或减弱

    2022-09-16 00:15:04
  • 浅析Python语言自带的数据结构有哪些

    2022-01-14 04:08:44
  • 网站升级兼容firefox经验小谈

    2007-10-28 20:28:00
  • windows10在visual studio2019下配置使用openCV4.3.0

    2021-10-23 12:23:31
  • python实现求解列表中元素的排列和组合问题

    2022-03-18 00:05:08
  • python实现windows下文件备份脚本

    2021-05-06 06:32:41
  • python图像和办公文档处理总结

    2021-03-08 19:24:02
  • win10环境下使用Hyper-V进行虚拟机创建的教程(图解)

    2022-08-01 02:25:06
  • 通过Cursor工具使用GPT-4的方法详解

    2023-08-28 05:08:34
  • python实现自动化办公邮件合并功能

    2022-02-22 21:06:06
  • asp之家 网络编程 m.aspxhome.com