Keras实现将两个模型连接到一起

作者:木盏 时间:2021-07-10 07:24:08 

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):


class AE:
def __init__(self, dim, img_dim, batch_size):
 self.dim = dim
 self.img_dim = img_dim
 self.batch_size = batch_size
 self.encoder = self.encoder_construct()
 self.decoder = self.decoder_construct()

def encoder_construct(self):
 x_in = Input(shape=(self.img_dim, self.img_dim, 3))
 x = x_in
 x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
 x = BatchNormalization()(x)
 x = LeakyReLU(0.2)(x)
 x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
 x = BatchNormalization()(x)
 x = LeakyReLU(0.2)(x)
 x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
 x = BatchNormalization()(x)
 x = LeakyReLU(0.2)(x)
 x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
 x = BatchNormalization()(x)
 x = LeakyReLU(0.2)(x)
 x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
 x = BatchNormalization()(x)
 x = LeakyReLU(0.2)(x)
 x = GlobalAveragePooling2D()(x)
 encoder = Model(x_in, x)
 return encoder

def decoder_construct(self):
 map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
 # print(type(map_size))
 z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
 z = z_in
 z_dim = self.dim
 z = Dense(np.prod(map_size) * z_dim)(z)
 z = Reshape(map_size + (z_dim,))(z)
 z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
 z = BatchNormalization()(z)
 z = Activation('relu')(z)
 z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
 z = BatchNormalization()(z)
 z = Activation('relu')(z)
 z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
 z = BatchNormalization()(z)
 z = Activation('relu')(z)
 z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
 z = BatchNormalization()(z)
 z = Activation('relu')(z)
 z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
 z = Activation('tanh')(z)
 decoder = Model(z_in, z)
 return decoder

def build_ae(self):
 input_x = Input(shape=(self.img_dim, self.img_dim, 3))
 x = input_x
 for i in range(1, len(self.encoder.layers)):
  x = self.encoder.layers[i](x)
 for j in range(1, len(self.decoder.layers)):
  x = self.decoder.layers[j](x)
 y = x
 auto_encoder = Model(input_x, y)
 return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:


weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

来源:https://blog.csdn.net/leviopku/article/details/83510927

标签:Keras,模型,连接
0
投稿

猜你喜欢

  • Mysql精粹系列(精粹)

    2024-01-21 02:27:05
  • Python异常处理如何才能写得优雅(retrying模块)

    2023-07-13 05:50:44
  • 一个asp正则替换的方法

    2008-11-25 14:05:00
  • Golang并发编程之Channel详解

    2024-05-09 14:58:42
  • Js sort排序使用方法

    2023-10-19 10:20:55
  • MySQL基础教程第一篇 mysql5.7.18安装和连接教程

    2024-01-15 18:55:19
  • 如何利用Python打开txt格式的文件

    2022-06-01 02:08:36
  • IE6中隐形的PNG8图片

    2009-11-27 18:38:00
  • python学习字符串驻留与常量折叠隐藏特性详解

    2021-01-19 14:40:10
  • MySQL内连接和外连接及七种SQL JOINS的实现

    2024-01-21 09:23:16
  • 使用PyCharm创建Django项目及基本配置详解

    2021-03-31 10:51:36
  • java 中JDBC连接数据库代码和步骤详解及实例代码

    2024-01-27 16:35:14
  • MySQL异常处理浅析

    2024-01-17 21:47:44
  • Vue中axios的封装(报错、鉴权、跳转、拦截、提示)

    2024-05-02 17:06:03
  • 一波神奇的Python语句、函数与方法的使用技巧总结

    2023-05-12 19:48:41
  • 关于mysql基础知识的介绍

    2024-01-18 10:57:28
  • CTF中的PHP特性函数解析之上篇

    2023-06-14 02:19:58
  • thinkphp(php)插件钩子(hooks)分析的简单实现机制

    2023-05-25 09:27:58
  • Python 编码Basic Auth使用方法简单实例

    2023-06-13 22:29:50
  • Python实现字符串格式化输出的方法详解

    2023-05-02 20:52:59
  • asp之家 网络编程 m.aspxhome.com