Keras实现将两个模型连接到一起
作者:木盏 时间:2021-07-10 07:24:08
神经网络玩得越久就越会尝试一些网络结构上的大改动。
先说意图
有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。
流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。
实现方法
首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。
第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。
第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。
可以看一个自编码器的代码(本人所编写):
class AE:
def __init__(self, dim, img_dim, batch_size):
self.dim = dim
self.img_dim = img_dim
self.batch_size = batch_size
self.encoder = self.encoder_construct()
self.decoder = self.decoder_construct()
def encoder_construct(self):
x_in = Input(shape=(self.img_dim, self.img_dim, 3))
x = x_in
x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = GlobalAveragePooling2D()(x)
encoder = Model(x_in, x)
return encoder
def decoder_construct(self):
map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
# print(type(map_size))
z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
z = z_in
z_dim = self.dim
z = Dense(np.prod(map_size) * z_dim)(z)
z = Reshape(map_size + (z_dim,))(z)
z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = Activation('tanh')(z)
decoder = Model(z_in, z)
return decoder
def build_ae(self):
input_x = Input(shape=(self.img_dim, self.img_dim, 3))
x = input_x
for i in range(1, len(self.encoder.layers)):
x = self.encoder.layers[i](x)
for j in range(1, len(self.decoder.layers)):
x = self.decoder.layers[j](x)
y = x
auto_encoder = Model(input_x, y)
return auto_encoder
模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。
补充知识:keras得到每层的系数
使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:
weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)
这样系数就被存放到一个np中了。
来源:https://blog.csdn.net/leviopku/article/details/83510927
标签:Keras,模型,连接
0
投稿
猜你喜欢
Mysql精粹系列(精粹)
2024-01-21 02:27:05
Python异常处理如何才能写得优雅(retrying模块)
2023-07-13 05:50:44
一个asp正则替换的方法
2008-11-25 14:05:00
Golang并发编程之Channel详解
2024-05-09 14:58:42
Js sort排序使用方法
2023-10-19 10:20:55
MySQL基础教程第一篇 mysql5.7.18安装和连接教程
2024-01-15 18:55:19
如何利用Python打开txt格式的文件
2022-06-01 02:08:36
IE6中隐形的PNG8图片
2009-11-27 18:38:00
python学习字符串驻留与常量折叠隐藏特性详解
2021-01-19 14:40:10
MySQL内连接和外连接及七种SQL JOINS的实现
2024-01-21 09:23:16
使用PyCharm创建Django项目及基本配置详解
2021-03-31 10:51:36
java 中JDBC连接数据库代码和步骤详解及实例代码
2024-01-27 16:35:14
MySQL异常处理浅析
2024-01-17 21:47:44
Vue中axios的封装(报错、鉴权、跳转、拦截、提示)
2024-05-02 17:06:03
一波神奇的Python语句、函数与方法的使用技巧总结
2023-05-12 19:48:41
关于mysql基础知识的介绍
2024-01-18 10:57:28
CTF中的PHP特性函数解析之上篇
2023-06-14 02:19:58
thinkphp(php)插件钩子(hooks)分析的简单实现机制
2023-05-25 09:27:58
Python 编码Basic Auth使用方法简单实例
2023-06-13 22:29:50
Python实现字符串格式化输出的方法详解
2023-05-02 20:52:59