keras回调函数的使用

作者:辛勤的小码农^_^ 时间:2022-08-22 11:42:41 

回调函数

  • 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用

  • 可以访问关于模型状态与模型性能的所有可用数据

  • 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。

  • 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。

  • 在训练过程中动态调节某些参数值:比如调节优化器的学习率。

  • 在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

fit()方法中使用callbacks参数

# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
   keras.callbacks.EarlyStopping(
       monitor="val_accuracy",#监控指标
       patience=2 #两轮内不再改善中断训练
   ),
   keras.callbacks.ModelCheckpoint(
       filepath="checkpoint_path",
       monitor="val_loss",
       save_best_only=True
   )
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"])

model.fit(train_images,train_labels,
        epochs=10,callbacks=callbacks_list, #该参数使用回调函数
        validation_data=(val_images,val_labels))

test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率

keras回调函数的使用

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
   def on_train_begin(self, logs):
       self.per_batch_losses = []

def on_batch_end(self, batch, logs):
       self.per_batch_losses.append(logs.get("loss"))

def on_epoch_end(self, epoch, logs):
       plt.clf()
       plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                label="Training loss for each batch")
       plt.xlabel(f"Batch (epoch {epoch})")
       plt.ylabel("Loss")
       plt.legend()
       plt.savefig(f"plot_at_epoch_{epoch}")
       self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
             loss="sparse_categorical_crossentropy",
             metrics=["accuracy"])
model.fit(train_images, train_labels,
         epochs=10,
         callbacks=[LossHistory()],
         validation_data=(val_images, val_labels))

keras回调函数的使用

【其他】模型的定义 和 数据加载

def get_minist_model():
   inputs=keras.Input(shape=(28*28,))
   features=layers.Dense(512,activation="relu")(inputs)
   features=layers.Dropout(0.5)(features)
   outputs=layers.Dense(10,activation="softmax")(features)
   model=keras.Model(inputs,outputs)
   return model

#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]

来源:https://blog.csdn.net/qq_43787439/article/details/129438485

标签:keras,回调函数
0
投稿

猜你喜欢

  • Python 对象中的数据类型

    2022-01-25 00:58:35
  • mysql4.1以上版本连接时出现Client does not support authentication protocol问题解决办法

    2023-11-18 06:10:15
  • 在ASP中按指定参数格式化显示时间的函数。

    2010-05-27 12:29:00
  • 在SQL Server中编写通用数据访问方法

    2009-01-20 11:35:00
  • [多图] Google Chrome 试用 Tips

    2009-12-09 15:49:00
  • ASP实现文件直接下载的代码

    2011-04-11 10:56:00
  • min-height 的原始实现方式

    2008-06-29 15:04:00
  • CSS网页设计时关于字体大小的设计

    2008-10-23 13:42:00
  • 详解django+django-celery+celery的整合实战

    2022-11-14 12:25:13
  • Python爬取网易云音乐上评论火爆的歌曲

    2021-09-16 11:49:53
  • 浅谈numpy广播机制

    2023-08-25 22:07:51
  • [译]艺术和设计的差异 (1)

    2009-09-25 12:38:00
  • python爬取音频下载的示例代码

    2023-07-25 09:59:57
  • javascript权威指南,学习笔记,之运算符号

    2008-04-20 16:43:00
  • Python高阶函数与装饰器函数的深入讲解

    2023-10-04 12:42:41
  • ASP+SQL Server构建网页防火墙

    2009-01-21 19:56:00
  • python中强制关闭线程与协程与进程方法

    2023-05-11 10:24:48
  • 简单的PHP缓存设计实现代码

    2023-10-25 19:58:08
  • Python imutils 填充图片周边为黑色的实现

    2021-04-13 04:06:32
  • 解读ASP.NET 5 & MVC6系列教程(12):基于Lamda表达式的强类型Routing实现

    2023-06-28 15:17:35
  • asp之家 网络编程 m.aspxhome.com