keras 自定义loss层+接受输入实例

作者:lgy_keira 时间:2023-09-23 16:37:55 

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).


def custom_loss_wrapper(input_tensor):
def custom_loss(y_true, y_pred):
 return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
return custom_loss

input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.


X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)


### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:


# 方式一
def vae_loss(x, x_decoded_mean):
xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
return xent_loss + kl_loss

vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:


# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):

def __init__(self, **kwargs):
 self.is_placeholder = True
 super(CustomVariationalLayer, self).__init__(**kwargs)
def vae_loss(self, x, x_decoded_mean_squash):

x = K.flatten(x)
 x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
 xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
 kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
 return K.mean(xent_loss + kl_loss)

def call(self, inputs):

x = inputs[0]
 x_decoded_mean_squash = inputs[1]
 loss = self.vae_loss(x, x_decoded_mean_squash)
 self.add_loss(loss, inputs=inputs)
 # We don't use this output.
 return x

y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重


# Import
import numpy as np
from sklearn.utils import class_weight

# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
   loss='binary_crossentropy',
   metrics=['accuracy'])

# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
          np.unique(y_train),
          y_train)

# Add the class weights to the training          
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

来源:https://blog.csdn.net/u013608336/article/details/82559469

标签:keras,loss,输入
0
投稿

猜你喜欢

  • MySql 5.7.21免安装版本win10下的配置方法

    2024-01-23 00:45:35
  • python给视频添加背景音乐并改变音量的具体方法

    2021-01-26 20:18:47
  • 在ASP.NET 2.0中操作数据之十一:基于数据的自定义格式化

    2023-07-14 19:53:21
  • Vue3新属性之css中使用v-bind的方法(v-bind in css)

    2024-05-28 16:01:07
  • 调用其他python脚本文件里面的类和方法过程解析

    2021-01-11 13:27:14
  • PyQT实现菜单中的复制,全选和清空的功能的方法

    2023-08-13 03:09:23
  • 如何在页面中快捷地添加翻页按钮?

    2010-06-26 12:33:00
  • MySQL表设计优化与索引 (六)

    2010-10-25 19:53:00
  • python中numpy基础学习及进行数组和矢量计算

    2023-01-22 16:32:04
  • Google的用户体验设计原则

    2009-01-12 18:31:00
  • 深入探索数据库MySQL性能优化与复杂查询相关操作

    2024-01-26 20:25:11
  • 为你总结一些php信息函数

    2023-10-28 09:46:59
  • ASP小偷(远程数据获取)程序的入门教程

    2007-09-21 12:48:00
  • Python实现一元一次与一元二次方程求解

    2022-03-30 14:09:32
  • Pytorch深度学习经典卷积神经网络resnet模块训练

    2022-12-02 01:43:23
  • Python descriptor(描述符)的实现

    2021-12-22 09:07:56
  • python 制作网站小说下载器

    2021-06-07 23:04:42
  • 谈谈网页设计中的字体应用 (4) 实战应用篇·下

    2009-11-24 13:13:00
  • 7个鲜为人知却非常实用的PHP函数

    2023-10-15 03:46:47
  • Python字符串常用方法以及其应用场景详解

    2022-02-15 18:39:53
  • asp之家 网络编程 m.aspxhome.com