keras 获取某层的输入/输出 tensor 尺寸操作
作者:TinaO-O 时间:2021-11-14 09:55:06
获取单输入尺寸,该层只被使用了一次。
import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))
conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)
# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)
如果该层被使用了两次
import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))
conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)
# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)
conved_b = conv(b)
# 现在 `.input_shape` 属性不可行,但是这样可以:
assert conv.get_input_shape_at(0) == (None, 32, 32, 3)
assert conv.get_input_shape_at(1) == (None, 64, 64, 3)
如果是输出,只需要改成output就好:
import keras
from keras.layers import Input, LSTM, Dense, Conv2D
from keras.models import Model
a = Input(shape=(32, 32, 3))
b = Input(shape=(64, 64, 3))
conv = Conv2D(16, (3, 3), padding='same')
conved_a = conv(a)
# 到目前为止只有一个输入,以下可行:
assert conv.input_shape == (None, 32, 32, 3)
conved_b = conv(b)
# 就改了output,当然尺寸我也改了
assert conv.get_output_shape_at(0) == (None, 32, 32, 16)
assert conv.get_output_shape_at(1) == (None, 64, 64, 16)
补充知识:keras中获取shape的正确方法
在keras的网络中,如果用layer_name.shape的方式获取shape信息将会返还tensorflow.python.framework.tensor_shape.TensorShape其中包含的是tensorflow.python.framework.tensor_shape.Dimension
正确的方式是使用
import keras.backend as K
K.int_shape(laye_name)
来源:https://blog.csdn.net/u013249853/article/details/89191441
标签:keras,输入,输出,tensor,尺寸
0
投稿
猜你喜欢
vue切换页面(路由)时如何保持滚动条回到顶部
2024-05-28 15:54:49
Python3+pycuda实现执行简单GPU计算任务
2022-06-04 09:55:29
使用mysql_udf与curl库完成http_post通信模块示例
2024-01-21 15:56:04
ASP动态包含文件的改进方法
2009-01-05 12:22:00
python 两个一样的字符串用==结果为false问题的解决
2023-01-24 08:30:59
python 利用turtle模块画出没有角的方格
2022-03-09 04:25:04
网页制作前台之javascript
2013-07-23 08:32:59
MySQL错误TIMESTAMP column with CURRENT_TIMESTAMP的解决方法
2024-01-25 20:47:47
Python3爬虫学习之爬虫利器Beautiful Soup用法分析
2021-04-13 07:01:50
python利用paramiko连接远程服务器执行命令的方法
2021-07-19 01:07:34
Python Pivot table透视表使用方法解析
2021-06-21 10:22:59
如何在不同版本的SQL Server中存储数据
2009-01-15 13:06:00
Python根据欧拉角求旋转矩阵的实例
2022-09-03 15:11:00
使用Python的Django框架结合jQuery实现AJAX购物车页面
2023-05-21 01:59:28
MAC下MYSQL5.7.17连接不上的问题及解决办法
2024-01-15 00:35:32
Python文件读写保存操作的示例代码
2022-03-20 01:21:23
利用JavaScript实现拖拽改变元素大小
2024-06-10 12:00:05
PHP扩展开发入门教程
2024-05-05 09:17:51
关于MySQL的体系结构及存储引擎图解
2024-01-20 14:52:46
浅谈Pandas Series 和 Numpy array中的相同点
2022-06-11 15:20:17