keras K.function获取某层的输出操作

作者:脱贫&&脱单&&不脱发 时间:2023-03-11 15:10:21 

如下所示:


from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])#指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output])#指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:


def function(inputs, outputs, updates=None, **kwargs):
"""Instantiates a Keras function.
Arguments:
  inputs: List of placeholder tensors.
  outputs: List of output tensors.
  updates: List of update ops.
  **kwargs: Passed to `tf.Session.run`.
Returns:
  Output values as Numpy arrays.
Raises:
  ValueError: if invalid kwargs are passed in.
"""
if kwargs:
 for key in kwargs:
  if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
    key not in tf_inspect.getargspec(Function.__init__)[0]):
   msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
       'backend') % key
   raise ValueError(msg)
return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。


class Function(object):
"""Runs a computation graph.
Arguments:
  inputs: Feed placeholders to the computation graph.
  outputs: Output tensors to fetch.
  updates: Additional update ops to be run at function call.
  name: a name to help users identify what this function does.
"""

def __init__(self, inputs, outputs, updates=None, name=None,
       **session_kwargs):
 updates = updates or []
 if not isinstance(inputs, (list, tuple)):
  raise TypeError('`inputs` to a TensorFlow backend function '
          'should be a list or tuple.')
 if not isinstance(outputs, (list, tuple)):
  raise TypeError('`outputs` of a TensorFlow backend function '
          'should be a list or tuple.')
 if not isinstance(updates, (list, tuple)):
  raise TypeError('`updates` in a TensorFlow backend function '
          'should be a list or tuple.')
 self.inputs = list(inputs)
 self.outputs = list(outputs)
 with ops.control_dependencies(self.outputs):
  updates_ops = []
  for update in updates:
   if isinstance(update, tuple):
    p, new_p = update
    updates_ops.append(state_ops.assign(p, new_p))
   else:
    # assumed already an op
    updates_ops.append(update)
  self.updates_op = control_flow_ops.group(*updates_ops)
 self.name = name
 self.session_kwargs = session_kwargs

def __call__(self, inputs):
 if not isinstance(inputs, (list, tuple)):
  raise TypeError('`inputs` should be a list or tuple.')
 feed_dict = {}
 for tensor, value in zip(self.inputs, inputs):
  if is_sparse(tensor):
   sparse_coo = value.tocoo()
   indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                np.expand_dims(sparse_coo.col, 1)), 1)
   value = (indices, sparse_coo.data, sparse_coo.shape)
  feed_dict[tensor] = value
 session = get_session()
 updated = session.run(
   self.outputs + [self.updates_op],
   feed_dict=feed_dict,
   **self.session_kwargs)
 return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

来源:https://blog.csdn.net/qq_37974048/article/details/102727653

标签:keras,K.function,层输出
0
投稿

猜你喜欢

  • javascript设计模式之模块模式学习笔记

    2024-04-29 13:16:11
  • python实现从尾到头打印单链表操作示例

    2021-12-20 00:09:32
  • Django Admin后台模型列表页面如何添加自定义操作按钮

    2021-02-24 18:57:15
  • Python3 Post登录并且保存cookie登录其他页面的方法

    2023-08-18 22:45:52
  • 手把手教你pip配置国内镜像源(最新详尽版)

    2023-05-30 10:19:03
  • 深刻理解Oracle数据库的启动和关闭

    2010-07-26 13:08:00
  • Python实现批量识别图片文字并存为Excel

    2021-07-28 06:34:23
  • python3.6实现学生信息管理系统

    2021-02-09 20:54:48
  • 详解webpack编译多页面vue项目的配置问题

    2024-06-15 00:50:22
  • 用Pycharm实现鼠标滚轮控制字体大小的方法

    2023-02-08 15:34:32
  • 油猴脚本编写教程详解

    2023-05-26 12:29:51
  • linux环境下安装mysql数据库的详细教程

    2024-01-15 02:12:31
  • Python的Flask框架中@app.route的用法教程

    2022-05-14 07:25:19
  • 如何修改vue-treeSelect的高度

    2024-05-08 09:33:55
  • Python实现获取乱序列表排序后的新下标的示例

    2021-04-25 10:36:42
  • python中pygame安装过程(超级详细)

    2022-08-05 04:26:15
  • 谈谈如何管理门户级网站的CSS/IMG/JS文件

    2009-09-03 11:48:00
  • python实现简单神经网络算法

    2021-03-22 07:03:16
  • MySQL插入时间差八小时问题的解决方法

    2024-01-28 22:00:09
  • 解决MybatisPlus SqlServer OFFSET 分页问题

    2024-01-12 16:26:24
  • asp之家 网络编程 m.aspxhome.com