keras中的loss、optimizer、metrics用法
作者:wyf 发布时间:2022-06-15 15:53:07
用keras搭好模型架构之后的下一步,就是执行编译操作。在编译时,经常需要指定三个参数
loss
optimizer
metrics
这三个参数有两类选择:
使用字符串
使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数
例如:
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy'])
因为有时可以使用字符串,有时可以使用标识符,令人很想知道背后是如何操作的。下面分别针对optimizer,loss,metrics三种对象的获取进行研究。
optimizer
一个模型只能有一个optimizer,在执行编译的时候只能指定一个optimizer。
在keras.optimizers.py中,有一个get函数,用于根据用户传进来的optimizer参数获取优化器的实例:
def get(identifier):
# 如果后端是tensorflow并且使用的是tensorflow自带的优化器实例,可以直接使用tensorflow原生的优化器
if K.backend() == 'tensorflow':
# Wrap TF optimizer instances
if isinstance(identifier, tf.train.Optimizer):
return TFOptimizer(identifier)
# 如果以json串的形式定义optimizer并进行参数配置
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
# 如果以字符串形式指定optimizer,那么使用优化器的默认配置参数
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
if isinstance(identifier, Optimizer):
# 如果使用keras封装的Optimizer的实例
return identifier
else:
raise ValueError('Could not interpret optimizer identifier: ' +
str(identifier))
其中,deserilize(config)函数的作用就是把optimizer反序列化制造一个实例。
loss
keras.losses函数也有一个get(identifier)方法。其中需要注意以下一点:
如果identifier是可调用的一个函数名,也就是一个自定义的损失函数,这个损失函数返回值是一个张量。这样就轻而易举的实现了自定义损失函数。除了使用str和dict类型的identifier,我们也可以直接使用keras.losses包下面的损失函数。
def get(identifier):
if identifier is None:
return None
if isinstance(identifier, six.string_types):
identifier = str(identifier)
return deserialize(identifier)
if isinstance(identifier, dict):
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'loss function identifier:', identifier)
metrics
在model.compile()函数中,optimizer和loss都是单数形式,只有metrics是复数形式。因为一个模型只能指明一个optimizer和loss,却可以指明多个metrics。metrics也是三者中处理逻辑最为复杂的一个。
在keras最核心的地方keras.engine.train.py中有如下处理metrics的函数。这个函数其实就做了两件事:
根据输入的metric找到具体的metric对应的函数
计算metric张量
在寻找metric对应函数时,有两种步骤:
使用字符串形式指明准确率和交叉熵
使用keras.metrics.py中的函数
def handle_metrics(metrics, weights=None):
metric_name_prefix = 'weighted_' if weights is not None else ''
for metric in metrics:
# 如果metrics是最常见的那种:accuracy,交叉熵
if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
# custom handling of accuracy/crossentropy
# (because of class mode duality)
output_shape = K.int_shape(self.outputs[i])
# 如果输出维度是1或者损失函数是二分类损失函数,那么说明是个二分类问题,应该使用二分类的accuracy和二分类的的交叉熵
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
# case: binary accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.binary_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.binary_crossentropy
# 如果损失函数是sparse_categorical_crossentropy,那么目标y_input就不是one-hot的,所以就需要使用sparse的多类准去率和sparse的多类交叉熵
elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
# case: categorical accuracy/crossentropy
# with sparse targets
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.sparse_categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.sparse_categorical_crossentropy
else:
# case: categorical accuracy/crossentropy
if metric in ('accuracy', 'acc'):
metric_fn = metrics_module.categorical_accuracy
elif metric in ('crossentropy', 'ce'):
metric_fn = metrics_module.categorical_crossentropy
if metric in ('accuracy', 'acc'):
suffix = 'acc'
elif metric in ('crossentropy', 'ce'):
suffix = 'ce'
weighted_metric_fn = weighted_masked_objective(metric_fn)
metric_name = metric_name_prefix + suffix
else:
# 如果输入的metric不是字符串,那么就调用metrics模块获取
metric_fn = metrics_module.get(metric)
weighted_metric_fn = weighted_masked_objective(metric_fn)
# Get metric name as string
if hasattr(metric_fn, 'name'):
metric_name = metric_fn.name
else:
metric_name = metric_fn.__name__
metric_name = metric_name_prefix + metric_name
with K.name_scope(metric_name):
metric_result = weighted_metric_fn(y_true, y_pred,
weights=weights,
mask=masks[i])
# Append to self.metrics_names, self.metric_tensors,
# self.stateful_metric_names
if len(self.output_names) > 1:
metric_name = self.output_names[i] + '_' + metric_name
# Dedupe name
j = 1
base_metric_name = metric_name
while metric_name in self.metrics_names:
metric_name = base_metric_name + '_' + str(j)
j += 1
self.metrics_names.append(metric_name)
self.metrics_tensors.append(metric_result)
# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
if isinstance(metric_fn, Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
无论怎么使用metric,最终都会变成metrics包下面的函数。当使用字符串形式指明accuracy和crossentropy时,keras会非常智能地确定应该使用metrics包下面的哪个函数。因为metrics包下的那些metric函数有不同的使用场景,例如:
有的处理的是one-hot形式的y_input(数据的类别),有的处理的是非one-hot形式的y_input
有的处理的是二分类问题的metric,有的处理的是多分类问题的metric
当使用字符串“accuracy”和“crossentropy”指明metric时,keras会根据损失函数、输出层的shape来确定具体应该使用哪个metric函数。在任何情况下,直接使用metrics下面的函数名是总不会出错的。
keras.metrics.py文件中也有一个get(identifier)函数用于获取metric函数。
def get(identifier):
if isinstance(identifier, dict):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif isinstance(identifier, six.string_types):
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'metric function identifier:', identifier)
如果identifier是字符串或者字典,那么会根据identifier反序列化出一个metric函数。
如果identifier本身就是一个函数名,那么就直接返回这个函数名。这种方式就为自定义metric提供了巨大便利。
keras中的设计哲学堪称完美。
来源:https://www.cnblogs.com/weiyinfu/p/9783776.html
猜你喜欢
- 最近决定把MT的后台数据从Berkeley的文件DB转到MySQL。原因之一是使用关系数据库可以获得更多的灵活性,比如运行一条sql来变更
- 删除Git缓存的用户名和密码昨天在上传代码的时候提示输入用户名密码,结果输错了3次就没有提示框了,就一直报错(身份验证失败),没办法提交代。
- 一、视图的基本概念视图是用于查询的另外一种方式。 与实际的表不同,它是一个虚表;因此数据库中只存在视图的定义,而不存在视图中相对应的数据,数
- 由于Access数据库是一种文件型数据库,所以无法跨服务器进行访问。下面我们来介绍一下如何利用SQL Server 的链接服务器,把地理上分
- 一、打包多个1、将需要打包的项目为anjuke_sd目录下的所有python文件,其中excute_main.py为主文件。2、生成主函数对
- 好吧,我承认我是对晚上看到一张合适的票转让但打过电话去说已经被搞走了这件事情感到蛋疼。直接上文件吧。#coding: utf-8'&
- 客户强烈要求使用淘宝的首页商品分类效果,很BT~,没辙就满足一下人家的需求。通过淘宝案例,立即想到了显示/隐藏层的效果,于是在DW中画了几个
- 我就废话不多说了,直接上代码吧!import turtleturtle.pensize(5)turtle.pencolor("ye
- 本程序有两文件test.asp 和tree.asp 还有一些图标文件 1。test.asp 调用类生成树 代码如下<%@
- 假设某宝为鼓励大家双12买买买,奖励双十一那天订单最多的两位用户:分别是用户1:“剁手皇帝陈哈哈” 和 用户2:“触手怪刘大莉” 一人一万元
- 1.C++ 代码Demo.h#pragma oncevoid GeneratorGaussKernel(int ksize, float s
- 使用MySQL Administrator 登录,报错: Either the server service or the configur
- 目录前言项目设计后端前端运行项目Q&A前言在前面的Api开发中,我们使用FastApi已经可以很好的实现。但是实际使用中,我们通常建
- 从一个 Demo 入手为了快速进入状态,我们先搞一个 Demo,当然这个 Demo 是参考 Go 源码 src/net/rpc/s
- 如下所示:#!/usr/bin/env python#coding: utf8import getpassdb = {}def newUse
- import os ## for os.path.isfile()def dealline(line) :
- 无限分类是实际开发中经常用到的一种数据结构,一般我们称之为树形结构。题设:类似淘宝的商品分类,可以在任意分类设置其子类。 一、创建
- 下面的路径介绍针对windows在编写的py文件中打开文件的时候经常见到下面其中路径的表达方式:open('aaa.txt'
- 需求背景最近为公司开发了一套邮件日报程序,邮件一般就是表格,图片,然后就是附件。附件一般都是默认写到txt文件里,但是PM希望邮件里的附件能
- import pyperclipimport pyautogui# PyAutoGUI中文输入需要用粘贴实现# Py