keras 获取某层的输入/输出 tensor 尺寸操作

作者:TinaO-O 时间:2021-11-14 09:55:06 

获取单输入尺寸,该层只被使用了一次。


import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)

如果该层被使用了两次


import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)

conved_b = conv(b)
# 现在 `.input_shape` 属性不可行,但是这样可以:
assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
assert conv.get_input_shape_at(1) == (None, 64, 64, 3)

如果是输出,只需要改成output就好:


import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))

conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)

# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)

conved_b = conv(b)
# 就改了output,当然尺寸我也改了
assert conv.get_output_shape_at(0) == (None, 32, 32, 16)
assert conv.get_output_shape_at(1) == (None, 64, 64, 16)

补充知识:keras中获取shape的正确方法

在keras的网络中,如果用layer_name.shape的方式获取shape信息将会返还tensorflow.python.framework.tensor_shape.TensorShape其中包含的是tensorflow.python.framework.tensor_shape.Dimension

正确的方式是使用

import keras.backend as K
K.int_shape(laye_name)

来源:https://blog.csdn.net/u013249853/article/details/89191441

标签:keras,输入,输出,tensor,尺寸
0
投稿

猜你喜欢

  • vue切换页面(路由)时如何保持滚动条回到顶部

    2024-05-28 15:54:49
  • Python3+pycuda实现执行简单GPU计算任务

    2022-06-04 09:55:29
  • 使用mysql_udf与curl库完成http_post通信模块示例

    2024-01-21 15:56:04
  • ASP动态包含文件的改进方法

    2009-01-05 12:22:00
  • python 两个一样的字符串用==结果为false问题的解决

    2023-01-24 08:30:59
  • python 利用turtle模块画出没有角的方格

    2022-03-09 04:25:04
  • 网页制作前台之javascript

    2013-07-23 08:32:59
  • MySQL错误TIMESTAMP column with CURRENT_TIMESTAMP的解决方法

    2024-01-25 20:47:47
  • Python3爬虫学习之爬虫利器Beautiful Soup用法分析

    2021-04-13 07:01:50
  • python利用paramiko连接远程服务器执行命令的方法

    2021-07-19 01:07:34
  • Python Pivot table透视表使用方法解析

    2021-06-21 10:22:59
  • 如何在不同版本的SQL Server中存储数据

    2009-01-15 13:06:00
  • Python根据欧拉角求旋转矩阵的实例

    2022-09-03 15:11:00
  • 使用Python的Django框架结合jQuery实现AJAX购物车页面

    2023-05-21 01:59:28
  • MAC下MYSQL5.7.17连接不上的问题及解决办法

    2024-01-15 00:35:32
  • Python文件读写保存操作的示例代码

    2022-03-20 01:21:23
  • 利用JavaScript实现拖拽改变元素大小

    2024-06-10 12:00:05
  • PHP扩展开发入门教程

    2024-05-05 09:17:51
  • 关于MySQL的体系结构及存储引擎图解

    2024-01-20 14:52:46
  • 浅谈Pandas Series 和 Numpy array中的相同点

    2022-06-11 15:20:17
  • asp之家 网络编程 m.aspxhome.com