tensorflow实现softma识别MNIST

作者:freedom098 时间:2021-02-17 22:32:56 

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:


#encoding=utf-8
__author__ = 'freedom'
import tensorflow as tf

def loadMNIST():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
return mnist

def softmax(mnist,rate=0.01,batchSize=50,epoch=20):
n = 784 # 向量的维度数目
m = None # 样本数,这里可以获取,也可以不获取
c = 10 # 类别数目

x = tf.placeholder(tf.float32,[m,n])
y = tf.placeholder(tf.float32,[m,c])

w = tf.Variable(tf.zeros([n,c]))
b = tf.Variable(tf.zeros([c]))

pred= tf.nn.softmax(tf.matmul(x,w)+b)
loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
opt = tf.train.GradientDescentOptimizer(rate).minimize(loss)

init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)
for index in range(epoch):
 avgLoss = 0
 batchNum = int(mnist.train.num_examples/batchSize)
 for batch in range(batchNum):
  batch_x,batch_y = mnist.train.next_batch(batchSize)
  _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y})
  avgLoss += Loss
 avgLoss /= batchNum
 print 'every epoch average loss is ',avgLoss

right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(right,tf.float32))
print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels}))

if __name__ == "__main__":
mnist = loadMNIST()
softmax(mnist)

来源:http://blog.csdn.net/freedom098/article/details/52116813

标签:tensorflow,softma,MNIST
0
投稿

猜你喜欢

  • JavaScript简单计算人的年龄示例

    2024-05-03 15:04:39
  • mysql中合并两个字段的方法分享

    2024-01-21 19:01:44
  • zookeeper python接口实例详解

    2023-03-11 01:34:48
  • mysql中取字符串中的数字的语句

    2024-01-15 02:16:15
  • python处理大日志文件

    2021-11-09 22:21:14
  • 网页禁用右键实现代码(JavaScript代码)

    2024-02-26 09:46:23
  • python中set()函数简介及实例解析

    2022-05-15 17:12:24
  • MySQL8.0中的窗口函数的示例代码

    2024-01-14 12:30:26
  • python 如何获取文件夹中的全部文件

    2022-09-10 16:48:11
  • Python+OpenCV之图像轮廓详解

    2023-08-10 18:59:42
  • Advanced SQL Injection with MySQL

    2024-01-24 18:09:24
  • 浅谈Keras中fit()和fit_generator()的区别及其参数的坑

    2022-04-18 07:22:26
  • 最新版 Windows10上安装Python 3.8.5的步骤详解

    2021-12-31 00:50:29
  • sqlserver获取当前日期的最大时间值

    2024-01-16 06:54:24
  • 在ASP中使用类,实现模块化

    2008-10-15 14:57:00
  • 能否用显示/隐藏层来控制FLASH播放与停止

    2008-10-27 14:08:00
  • 百度的图片轮换JS代码,支持FF

    2007-11-16 16:24:00
  • 可以举出一个最简单的计数器吗?

    2009-11-01 15:37:00
  • Javascript 中 var 和 let 、const 的区别及使用方法

    2024-05-09 15:07:41
  • CentOS 7上为PHP5安装suPHP的方法(彭哥)

    2024-03-24 23:34:12
  • asp之家 网络编程 m.aspxhome.com