tensorflow模型继续训练 fineturn实例

作者:-牧野- 时间:2023-07-10 12:53:09 

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。


# -*- coding: utf-8 -*-)
import tensorflow as tf

# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')

# 操作
result = tf.matmul(x, W) + b

# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))

# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

with tf.Session() as sess:
 # 初始化变量
 sess.run(tf.global_variables_initializer())
 saver = tf.train.Saver(max_to_keep=3)

# 这里x、y给固定的值
 x_s = [[3.0]]
 y_s = [[100.0]]

step = 0
 while (True):
   step += 1
   feed = {x: x_s, y: y_s}
   # 通过sess.run执行优化
   sess.run(train_step, feed_dict=feed)

if step % 1000 == 0:
     print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
     if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
       print ''
       # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
       print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
       print ''
       print("模型保存的W值 : %f" % sess.run(W))
       print("模型保存的b : %f" % sess.run(b))
       break
 saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:


step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08

final result of x×W+b = [[99.99978]](目标值是100.0)

模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。


# -*- coding: utf-8 -*-)
import tensorflow as tf

# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

with tf.Session() as sess:

# 初始化变量
 sess.run(tf.global_variables_initializer())

# saver = tf.train.Saver(max_to_keep=3)
 saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
 saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据

# 从保存模型中恢复变量
 graph = tf.get_default_graph()
 W = graph.get_tensor_by_name("w:0")
 b = graph.get_tensor_by_name("b:0")

print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
 print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))

# 操作
 result = tf.matmul(x, W) + b
 # 损失函数
 lost = tf.reduce_sum(tf.pow((result - y), 2))
 # 优化
 train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

# 这里x、y给固定的值
 x_s = [[3.0]]
 y_s = [[200.0]]

step = 0
 while (True):
   step += 1
   feed = {x: x_s, y: y_s}
   # 通过sess.run执行优化
   sess.run(train_step, feed_dict=feed)
   if step % 1000 == 0:
     print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
     if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
       print ''
       # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
       print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
       print("模型保存的W值 : %f" % sess.run(W))
       print("模型保存的b : %f" % sess.run(b))
       break
 saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:


从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07

final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:


saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:


saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:


# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

来源:https://blog.csdn.net/dcrmg/article/details/83031488

标签:tensorflow,模型,训练,fineturn
0
投稿

猜你喜欢

  • Python实现按照指定要求逆序输出一个数字的方法

    2023-12-21 23:37:46
  • python3 求约数的实例

    2023-12-29 03:18:59
  • 解决MySQL Varchar 类型尾部空格的问题

    2024-01-25 15:31:28
  • 用python对oracle进行简单性能测试

    2021-07-08 16:51:59
  • python如何变换环境

    2021-06-02 19:19:44
  • OpenCV学习之图像加噪与滤波的实现详解

    2022-09-20 04:40:57
  • Python中os和shutil模块实用方法集锦

    2021-04-19 08:45:12
  • ASP程序中使用断开的数据记录集的代码

    2012-12-04 20:20:28
  • Python 爬虫之Beautiful Soup模块使用指南

    2021-10-16 13:28:03
  • MYSQL中 char 和 varchar的区别

    2024-01-25 22:22:52
  • python实现播放音频和录音功能示例代码

    2023-08-20 23:23:15
  • 人民币的符号的正确表示法?一杠?两杠?¥还是¥呢?

    2010-03-24 12:21:00
  • 从MySQL 5.5发布看开源数据库新模式

    2010-01-03 19:54:00
  • Spring数据库多数据源路由配置过程图解

    2024-01-26 11:23:55
  • video.js添加自定义组件的方法

    2024-04-30 10:09:03
  • 使用access数据库时可能用到的数据转换

    2008-09-10 12:49:00
  • Python装饰器的应用场景及实例用法

    2022-06-24 16:09:03
  • python获取代码运行时间的实例代码

    2023-11-04 02:25:10
  • 18个Python脚本可加速你的编码速度(提示和技巧)

    2022-11-09 16:54:34
  • Python实现轻松切割MP3文件

    2023-09-23 21:40:32
  • asp之家 网络编程 m.aspxhome.com