How to Implement a Custom Op in TensorFlow

Author: Wilbur529  Date: 2021-11-01 15:56:38

『Preface』

Using the CTC beam search decoder as an example, this post briefly walks through the workflow for implementing a custom Op in TensorFlow.

Basic workflow

1. Define the Op interface


#include "tensorflow/core/framework/op.h"

REGISTER_OP("Custom")  
 .Input("custom_input: int32")
 .Output("custom_output: int32");

2. Implement the Op's Compute method (CPU) or a GPU kernel


#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class CustomOp : public OpKernel{
 public:
 explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
 void Compute(OpKernelContext* context) override {
  // Grab the input tensor.
 const Tensor& input_tensor = context->input(0);
 auto input = input_tensor.flat<int32>();
  // Create an output tensor.
 Tensor* output_tensor = NULL;
 OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                          &output_tensor));
 auto output = output_tensor->template flat<int32>();
  // Perform the actual computation here, operating on input and output.
  // ……
 }
};
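
For concreteness, here is a hedged sketch (not from the original post) of what the elided computation could look like, in this case just an identity copy; these lines would sit where the placeholder comment is inside Compute(), reusing the input and output variables defined above:

  // Purely illustrative body: copy each input element to the output unchanged.
  const int N = input.size();
  for (int i = 0; i < N; ++i) {
   output(i) = input(i);
  }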

3. Register the implemented kernel with the TensorFlow system

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

Customizing CTCBeamSearchDecoder

The parts of the TensorFlow source tree that correspond to this Op:

Op interface definition:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

Definition of the CTCBeamSearchDecoder itself:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.h

Op class wrapper and kernel registration:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

The modified Op, based on the TensorFlow source


#include <algorithm>
#include <vector>
#include <cmath>

#include "tensorflow/core/util/ctc/ctc_beam_search.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;

using namespace tensorflow;

REGISTER_OP("CTCBeamSearchDecoderWithParam")
 .Input("inputs: float")
 .Input("sequence_length: int32")
 .Attr("beam_width: int >= 1")
 .Attr("top_paths: int >= 1")
 .Attr("merge_repeated: bool = true")
 // Two newly added attributes:
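 // (Roughly, per tensorflow/core/util/ctc/ctc_beam_search.h: label_selection_size
 //  keeps only the N most probable labels at each time step, and
 //  label_selection_margin drops labels whose log-probability is more than this
 //  margin below the best label; 0 / a negative value disables each filter.)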
 .Attr("label_selection_size: int >= 0 = 0")
 .Attr("label_selection_margin: float")
 .Output("decoded_indices: top_paths * int64")
 .Output("decoded_values: top_paths * int64")
 .Output("decoded_shape: top_paths * int64")
 .Output("log_probability: float")
 .SetShapeFn([](InferenceContext* c) {
  ShapeHandle inputs;
  ShapeHandle sequence_length;

  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));

  // Get batch size from inputs and sequence_length.
  DimensionHandle batch_size;
  TF_RETURN_IF_ERROR(
    c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));

  int32 top_paths;
  TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));

  // Outputs.
  int out_idx = 0;
  for (int i = 0; i < top_paths; ++i) { // decoded_indices
   c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));
  }
  for (int i = 0; i < top_paths; ++i) { // decoded_values
   c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
  }
  ShapeHandle shape_v = c->Vector(2);
  for (int i = 0; i < top_paths; ++i) { // decoded_shape
   c->set_output(out_idx++, shape_v);
  }
  c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
  return Status::OK();
 });

typedef Eigen::ThreadPoolDevice CPUDevice;

inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
         int* c) {
 *c = 0;
 CHECK_LT(0, m.dimension(1));
 float p = m(r, 0);
 for (int i = 1; i < m.dimension(1); ++i) {
  if (m(r, i) > p) {
   p = m(r, i);
   *c = i;
  }
 }
 return p;
}

class CTCDecodeHelper {
public:
CTCDecodeHelper() : top_paths_(1) {}

inline int GetTopPaths() const { return top_paths_; }
void SetTopPaths(int tp) { top_paths_ = tp; }

Status ValidateInputsGenerateOutputs(
  OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
  Tensor** log_prob, OpOutputList* decoded_indices,
  OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
 Status status = ctx->input("inputs", inputs);
 if (!status.ok()) return status;
 status = ctx->input("sequence_length", seq_len);
 if (!status.ok()) return status;

 const TensorShape& inputs_shape = (*inputs)->shape();

 if (inputs_shape.dims() != 3) {
  return errors::InvalidArgument("inputs is not a 3-Tensor");
 }

 const int64 max_time = inputs_shape.dim_size(0);
 const int64 batch_size = inputs_shape.dim_size(1);

 if (max_time == 0) {
  return errors::InvalidArgument("max_time is 0");
 }
 if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
  return errors::InvalidArgument("sequence_length is not a vector");
 }

 if (!(batch_size == (*seq_len)->dim_size(0))) {
  return errors::FailedPrecondition(
    "len(sequence_length) != batch_size. ", "len(sequence_length): ",
    (*seq_len)->dim_size(0), " batch_size: ", batch_size);
 }

 auto seq_len_t = (*seq_len)->vec<int32>();

 for (int b = 0; b < batch_size; ++b) {
  if (!(seq_len_t(b) <= max_time)) {
   return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                    max_time);
  }
 }

 Status s = ctx->allocate_output(
   "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
 if (!s.ok()) return s;

 s = ctx->output_list("decoded_indices", decoded_indices);
 if (!s.ok()) return s;
 s = ctx->output_list("decoded_values", decoded_values);
 if (!s.ok()) return s;
 s = ctx->output_list("decoded_shape", decoded_shape);
 if (!s.ok()) return s;

 return Status::OK();
}

// sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
Status StoreAllDecodedSequences(
  const std::vector<std::vector<std::vector<int> > >& sequences,
  OpOutputList* decoded_indices, OpOutputList* decoded_values,
  OpOutputList* decoded_shape) const {
 // Calculate the total number of entries for each path
 const int64 batch_size = sequences.size();
 std::vector<int64> num_entries(top_paths_, 0);

// Calculate num_entries per path
 for (const auto& batch_s : sequences) {
  CHECK_EQ(batch_s.size(), top_paths_);
  for (int p = 0; p < top_paths_; ++p) {
   num_entries[p] += batch_s[p].size();
  }
 }

 for (int p = 0; p < top_paths_; ++p) {
  Tensor* p_indices = nullptr;
  Tensor* p_values = nullptr;
  Tensor* p_shape = nullptr;

  const int64 p_num = num_entries[p];

  Status s =
    decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
  if (!s.ok()) return s;
  s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
  if (!s.ok()) return s;
  s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
  if (!s.ok()) return s;

  auto indices_t = p_indices->matrix<int64>();
  auto values_t = p_values->vec<int64>();
  auto shape_t = p_shape->vec<int64>();

  int64 max_decoded = 0;
  int64 offset = 0;

  for (int64 b = 0; b < batch_size; ++b) {
   auto& p_batch = sequences[b][p];
   int64 num_decoded = p_batch.size();
   max_decoded = std::max(max_decoded, num_decoded);
   std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
   for (int64 t = 0; t < num_decoded; ++t, ++offset) {
    indices_t(offset, 0) = b;
    indices_t(offset, 1) = t;
   }
  }

  shape_t(0) = batch_size;
  shape_t(1) = max_decoded;
 }
 return Status::OK();
}

private:
int top_paths_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};

// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
public:
explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
 OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
 OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
 // Read the two newly added attributes from the attr list
 OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));
 OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));
 int top_paths;
 OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
 decode_helper_.SetTopPaths(top_paths);
}

void Compute(OpKernelContext* ctx) override {
 const Tensor* inputs;
 const Tensor* seq_len;
 Tensor* log_prob = nullptr;
 OpOutputList decoded_indices;
 OpOutputList decoded_values;
 OpOutputList decoded_shape;
 OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
             &decoded_values, &decoded_shape));

 auto inputs_t = inputs->tensor<float, 3>();
 auto seq_len_t = seq_len->vec<int32>();
 auto log_prob_t = log_prob->matrix<float>();

 const TensorShape& inputs_shape = inputs->shape();

 const int64 max_time = inputs_shape.dim_size(0);
 const int64 batch_size = inputs_shape.dim_size(1);
 const int64 num_classes_raw = inputs_shape.dim_size(2);
 OP_REQUIRES(
   ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
   errors::InvalidArgument("num_classes cannot exceed max int"));
 const int num_classes = static_cast<const int>(num_classes_raw);

 log_prob_t.setZero();

 std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;

 for (std::size_t t = 0; t < max_time; ++t) {
  input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
               batch_size, num_classes);
 }

 ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                     &beam_scorer_, 1 /* batch_size */,
                     merge_repeated_);
 // Configure label selection using the two attributes passed in
 beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);
 Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
 auto input_chip_t = input_chip.flat<float>();

 std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
 std::vector<float> log_probs;

 // Assumption: the blank index is num_classes - 1
 for (int b = 0; b < batch_size; ++b) {
  auto& best_paths_b = best_paths[b];
  best_paths_b.resize(decode_helper_.GetTopPaths());
  for (int t = 0; t < seq_len_t(b); ++t) {
   input_chip_t = input_list_t[t].chip(b, 0);
   auto input_bi =
     Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
   beam_search.Step(input_bi);
  }
  OP_REQUIRES_OK(
    ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                 &log_probs, merge_repeated_));

  beam_search.Reset();

  for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
   log_prob_t(b, bp) = log_probs[bp];
  }
 }

 OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
             best_paths, &decoded_indices, &decoded_values,
             &decoded_shape));
}

private:
CTCDecodeHelper decode_helper_;
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
bool merge_repeated_;
int beam_width_;
 // Two new data members that store the added attributes
int label_selection_size;
float label_selection_margin;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
           CTCBeamSearchDecoderWithParamOp);

Compiling the custom Op into a .so file

Create a new directory named custom_op under the tensorflow-master directory, then:

cd custom_op

Create a BUILD file in it and add the following:


cc_library(
 name = "ctc_decoder_with_param",
 srcs = [
     "new_beamsearch.cc"
     ] +
     glob(["boost_locale/**/*.hpp"]),
 includes = ["boost_locale"],
 copts = ["-std=c++11"],
 deps = ["//tensorflow/core:core",
     "//tensorflow/core/util/ctc",
     "//third_party/eigen3",
 ],
)

Build steps:

1. cd into the tensorflow-master directory

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. libctc_decoder_with_param.so is generated under bazel-bin/custom_op

Using the custom Op in a training (or inference) program

Define the following function in your program:


decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')

def decode_with_param(inputs, sequence_length, beam_width=100,
                      top_paths=1, merge_repeated=True):
    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
        decode_param_op_module.ctc_beam_search_decoder_with_param(
            inputs, sequence_length, beam_width=beam_width,
            top_paths=top_paths, merge_repeated=merge_repeated,
            label_selection_size=40, label_selection_margin=0.99))
    return (
        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
        log_probabilities)

You can then use this Op just like tf.nn.ctc_beam_search_decoder.

Source: https://blog.csdn.net/sinat_37532065/article/details/92085177
