tensorflow模型的save与restore,及checkpoint中读取变量方式
作者:J_______ll 发布时间:2022-07-27 17:41:43
创建一个NN
import tensorflow as tf
import numpy as np
#fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis] #shape(100,1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise #shape(100,1) + noise
tf_x = tf.placeholder(tf.float32, x.shape) #input x
tf_y = tf.placeholder(tf.float32, y.shape) #output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu) #hidden layer
o = tf.layers.dense(l, 1) #output layer
loss = tf.losses.mean_squared_error(tf_y, o ) #compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)
1.使用save对模型进行保存
sess= tf.Session()
sess.run(tf.global_variables_initializer()) #initialize var in graph
saver = tf.train.Saver() # define a saver for saving and restoring
for step in range(100): #train
sess.run(train_op,{tf_x:x, tf_y:y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False) # mate_graph is not recommend
生成三个文件,分别是checkpoint,.ckpt.data-00000-of-00001,.ckpt.index
2.使用restore对提取模型
在提取模型时,需要将模型结构再定义一遍,再将各参数加载出来
#bulid entire net again and restore
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)
sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver() # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')
3.有时会报错Not found:b1 not found in checkpoint
这时我们想知道我在文件中到底保存了什么内容,即需要读取出checkpoint中的tensor
import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params','params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and value
f = open('params.txt','w')
for key in var_to_shape_map: # write tensors' names and values in file
print(key,file=f)
print(reader.get_tensor(key),file=f)
f.close()
运行后生成一个params.txt文件,在其中可以看到模型的参数。
补充知识:TensorFlow按时间保存检查点
一 实例
介绍一种更简便地保存检查点功能的方法——tf.train.MonitoredTrainingSession函数,该函数可以直接实现保存及载入检查点模型的文件。
演示使用MonitoredTrainingSession函数来自动管理检查点文件。
二 代码
import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs = 2) as sess:
print(sess.run([global_step]))
while not sess.should_stop():
i = sess.run( step)
print( i)
三 运行结果
1 第一次运行后,会发现log文件夹下产生如下文件
2 第二次运行后,结果如下:
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from log/checkpoints\model.ckpt-15147
INFO:tensorflow:Saving checkpoints for 15147 into log/checkpoints\model.ckpt.
[15147]
15148
15149
15150
15151
15152
15153
15154
15155
15156
15157
15158
15159
四 说明
本例是按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。
可见程序自动载入检查点是从第15147次开始运行的。
五 注意
1 如果不设置save_checkpoint_secs参数,默认的保存时间是10分钟,这种按照时间保存的模式更适合用于使用大型数据集来训练复杂模型的情况。
2 使用该方法,必须要定义global_step变量,否则会报错误。
来源:https://blog.csdn.net/J_______ll/article/details/80186201


猜你喜欢
- Main.jsvar routeList = [];router.beforeEach((to, from, next) => { v
- python十进制转二进制python中十进制转二进制使用 bin() 函数。bin() 返回一个整数 int 或者长整数 long int
- 请定义函数,将列表[10, 1, 2, 20, 10, 3, 2, 1, 15, 20, 44, 56, 3, 2, 1]中的重复元素除去,
- 题目描述原题链接 :463. 岛屿的周长 - 力扣(LeetCode)给定一个 row x col 的二维网格地图 grid ,其中:gri
- 1. JSON简介JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,它是JavaScript的子
- 1.今天复习一下Vue自定义指令的代码,结果出现一个很无语的结果,先贴代码。2.<div id="example"
- 写在前面:在个别时候可能需要查看当前最新的事务 ID,以便做一些业务逻辑上的判断(例如利用事务 ID 变化以及前后时差,统计每次事务的响应时
- 1.Mysql中的数据类型varchar 动态字符串类型(最长255位),可以根据实际长度来动态分配空间,例如:varchar(100)ch
- 问题描述环境: CentOS6.5想在此环境下使用python3进行开发,但CentOS6.5默认的python环境是2.6.6版本。 之前
- 前言PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置
- 优化数据库的注意事项:1、关键字段建立索引。2、使用存储过程,它使SQL变得更加灵活和高效。3、备份数据库和清除垃圾数据。4、SQL语句语法
- 我们知道map() 会根据提供的函数对指定序列做映射。 第一个参数 function 以参数序列中的每一个元素调用 function函数,返
- 当我们建立一个数据库时,并且想将分散在各处的不同类型的数据库分类汇总在这个新建的数据库中时,尤其是在进行数据检验、净化和转换时,将会面临很大
- 重载:同一个类中,函数名一样,返回值或者参数类型,个数不一样的叫做重载。 覆盖:同名函数,同返回值类型,同参数的叫做覆盖。指的是子类对父类中
- 查询操作和性能优化1.基本操作增models.Tb1.objects.create(c1='xx', c2='oo&
- 如何选择速度最快的站点? <html><head><meta http-equiv=&qu
- 1 unittest框架unittest 是python 的单元测试框架,它主要有以下作用:提供用例组织与执行:当你的测试用例只有几条时,可
- 前言最近在解决一些算法优化的问题,为了实时性要求,必须精益求精的将资源利用率用到极致。同时对算法中一些处理进行多线程或者多进程处理。在对代码
- MySQL存储过程与存储函数的相关概念存储函数和存储过程的主要区别:存储函数一定会有返回值的存储过程不一定有返回值存储过程和函数能后将复杂的
- asp之家注:如果你学习过asp,并且在网络公司上过班,一定会接触到网购系统,网购系统可以说是一个典型的程序类型,而其中最重要,也是最关键的