在keras下实现多个模型的融合方式

作者:小风风12580 时间:2023-06-03 17:14:59 

在网上搜过发现关于keras下的模型融合框架其实很简单,奈何网上说了一大堆,这个东西官方文档上就有,自己写了个demo:


# Function:基于keras框架下实现,多个独立任务分类
# Writer: PQF
# Time: 2019/9/29

import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
import tensorflow as tf

# 生成训练集
dataset_size = 128*3
rdm = np.random.RandomState(1)
X = rdm.rand(dataset_size,2)
Y1 = [[int(x1+x2<1)] for (x1,x2) in X]
Y2 = [[int(x1+x2*x2<0.5)] for (x1,x2) in X]

X_train = X[:-2]
Y_train1 = Y1[:-2]
Y_train2 = Y2[:-2]

X_test = X[-2:dataset_size]
Y_test1 = Y1[-2:dataset_size]
Y_test2 = Y2[-2:dataset_size]

#网络一
input = Input(shape=(2,))
x = Dense(units=16,activation='relu')(input)
output = Dense(units=1,activation='sigmoid',name='output1')(x)

#网络二
input2 = Input(shape=(2,))
x2 = Dense(units=16,activation='relu')(input2)
output2 = Dense(units=1,activation='sigmoid',name='output2')(x2)

#模型合并
model = Model(inputs=[input,input2],outputs=[output,output2])
model.summary()

model.compile(optimizer='rmsprop',loss='binary_crossentropy',loss_weights=[1.0,1.0])
model.fit([X_train,X_train],[Y_train1,Y_train2],batch_size=48,epochs=200)

print('x_test is :\n')
print(X_test)
print('y_test1 is :\n')
print(Y_test1)
print('y_test2 is :\n')
print(Y_test2)

predict = model.predict([X_test,X_test])
print('prediction is : \n')
print(predict[0])
print(predict[1])

补充知识:keras的融合层使用理解

最近开始研究U-net网络,其中接触到了融合层的概念,做个笔记。

在keras下实现多个模型的融合方式

上图为U-net网络,其中上采样层(绿色箭头)需要与下采样层池化层(红色箭头)层进行融合,要求每层的图片大小一致,维度依照融合的方式可以不同,融合之后输出的图片相较于没有融合层的网络,边缘处要清晰很多!

这时候就要用到keras的融合层概念(Keras中文文档https://keras.io/zh/)

文档中分别讲述了加减乘除的四中融合方式,这种方式要求两层之间shape必须一致。

重点讲述一下Concatenate(拼接)方式

拼接方式默认依照最后一维也就是通道来进行拼接

在keras下实现多个模型的融合方式

如同上图(128*128*64)与(128*128*128)进行Concatenate之后的shape为128*128*192

ps:

中文文档为老版本,最新版本的keras.layers.merge方法进行了整合

在keras下实现多个模型的融合方式

上图为新版本整合之后的方法,具体使用方法一看就懂,不再赘述。

来源:https://blog.csdn.net/weixin_43392276/article/details/101757173

标签:keras,模型,融合
0
投稿

猜你喜欢

  • Python文件操作的方法

    2022-10-27 19:54:29
  • JS onmousemove鼠标移动坐标接龙DIV效果实例

    2023-08-08 19:59:13
  • Django限制API访问频率常用方法解析

    2022-06-24 18:20:13
  • python 顺时针打印矩阵的超简洁代码

    2023-03-25 14:03:52
  • Python基础常用内建函数图文示例解析

    2022-05-04 04:54:24
  • Python+drawpad实现CPU监控小程序

    2022-05-30 19:54:38
  • python中图像通道分离与合并实例

    2021-04-02 00:09:48
  • 深入了解Vue3中props的原理与使用

    2024-05-09 15:09:17
  • python基础入门学习笔记(Python环境搭建)

    2022-01-12 20:27:48
  • Python制作简易计算器功能

    2023-05-06 19:53:47
  • ASP实现文件直接下载

    2008-11-19 15:39:00
  • numpy中np.dstack()、np.hstack()、np.vstack()用法

    2021-08-27 11:47:42
  • 详解MySQL 数据库优化方法

    2010-08-12 14:50:00
  • MySQL8自增主键变化图文详解

    2024-01-25 19:08:38
  • Python中几种属性访问的区别与用法详解

    2022-12-24 23:36:20
  • python命名关键字参数的作用详解

    2023-09-01 10:35:37
  • python入门学习笔记分享

    2023-01-29 17:46:16
  • python处理json字符串(使用json.loads而不是eval())

    2023-06-13 11:50:39
  • Django框架安装及项目创建过程解析

    2022-09-20 12:55:45
  • python list转置和前后反转的例子

    2022-04-26 10:39:55
  • asp之家 网络编程 m.aspxhome.com