Python tensorflow与pytorch的浮点运算数如何计算

作者:浩哥依然 时间:2023-06-28 14:13:15 

1. 引言

FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。

2. 模型结构

为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。

表 1 模型结构及主要参数

LayerschannelsKernelsStridesUnitsActivation
Conv2D32(4,4)(1,2)\relu
GRU\\\96\
Dense\\\256sigmoid

用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:

from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model
def test_model_tf(Input_shape):
   # shape: [B, C, T, F]
   main_input = Input(batch_shape=Input_shape, name='main_inputs')
   conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
   # shape: [B, T, FC]
   gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
   gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
   output = Dense(256, activation='sigmoid', name='output')(gru)
   model = Model(inputs=[main_input], outputs=[output])
   return model

用 pytorch 实现该模型的代码为:

import torch
import torch.nn as nn
class test_model_torch(nn.Module):
   def __init__(self):
       super(test_model_torch, self).__init__()
       self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
       self.relu = nn.ReLU()
       self.gru = nn.GRU(input_size=4064, hidden_size=96)
       self.fc = nn.Linear(96, 256)
       self.sigmoid = nn.Sigmoid()
   def forward(self, inputs):
       # shape: [B, C, T, F]
       out = self.conv2d(inputs)
       out = self.relu(out)
       # shape: [B, T, FC]
       batch, channel, frame, freq = out.size()
       out = torch.reshape(out, (batch, frame, freq*channel))
       out, _ = self.gru(out)
       out = self.fc(out)
       out = self.sigmoid(out)
       return out

3. 计算模型的 FLOPs

本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。

3.1. tensorflow 1.12.0

在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:

import tensorflow as tf
import tensorflow.keras.backend as K
def get_flops(model):
   run_meta = tf.RunMetadata()
   opts = tf.profiler.ProfileOptionBuilder.float_operation()
   flops = tf.profiler.profile(graph=K.get_session().graph,
                               run_meta=run_meta, cmd='op', options=opts)
   return flops.total_float_ops
if __name__ == "__main__":
   x = K.random_normal(shape=(1, 1, 100, 256))
   model = test_model_tf(x.shape)
   print('FLOPs of tensorflow 1.12.0:', get_flops(model))

3.2. tensorflow 2.3.1

在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :

import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()
def get_flops(model):
   run_meta = tf.RunMetadata()
   opts = tf.profiler.ProfileOptionBuilder.float_operation()
   flops = tf.profiler.profile(graph=K.get_session().graph,
                               run_meta=run_meta, cmd='op', options=opts)
   return flops.total_float_ops
if __name__ == "__main__":
   x = K.random_normal(shape=(1, 1, 100, 256))
   model = test_model_tf(x.shape)
   print('FLOPs of tensorflow 2.3.1:', get_flops(model))

3.3. pytorch 1.10.1+cu102

在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):

import thop
x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)

需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。

3.4. 结果对比

三者计算出的 FLOPs 分别为:

tensorflow 1.12.0:

Python tensorflow与pytorch的浮点运算数如何计算

tensorflow 2.3.1:

Python tensorflow与pytorch的浮点运算数如何计算

pytorch 1.10.1:

Python tensorflow与pytorch的浮点运算数如何计算

可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。如读者知道其中详情,还请不吝赐教。

4. 总结

本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflow 与 pytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs。

来源:https://blog.csdn.net/wjrenxinlei/article/details/127973081

标签:Python,tensorflow,pytorch,浮点运算数
0
投稿

猜你喜欢

  • Matlab中plot基本用法的具体使用

    2022-08-14 10:28:24
  • Go如何优雅的使用字节池示例详解

    2024-02-10 21:10:17
  • z-index在IE中的迷惑

    2007-05-11 16:50:00
  • Python偏函数实现原理及应用

    2022-12-13 17:12:03
  • python 构造三维全零数组的方法

    2022-05-11 06:01:20
  • Python操作串口的方法

    2021-11-24 07:09:10
  • server.mappath方法详解

    2023-07-05 08:07:48
  • Java中使用正则表达式的一个简单例子及常用正则分享

    2023-05-06 09:03:16
  • flask session组件的使用示例

    2022-03-17 08:13:07
  • Python第三方库face_recognition在windows上的安装过程

    2023-07-27 02:51:29
  • 如何恢复MySQL主从数据一致性

    2024-01-26 23:34:33
  • ASP 写的判断 Money 各个位值的函数

    2008-04-13 06:36:00
  • Golang 按行读取文件的三种方法小结

    2024-02-20 18:45:29
  • selenium+python自动化测试之使用webdriver操作浏览器的方法

    2023-06-28 03:04:47
  • ThinkPHP 3.2.3实现页面静态化功能的方法详解

    2023-11-23 13:12:53
  • 详解Vuex管理登录状态

    2024-04-26 17:38:02
  • 给zblog加上运行代码功能

    2007-12-19 13:07:00
  • Python的randrange()方法使用教程

    2021-02-08 10:22:22
  • MySQL Workbench操作图文详解(史上最细)

    2024-01-14 01:44:14
  • python的迭代器,生成器和装饰器你了解吗

    2024-01-02 12:45:12
  • asp之家 网络编程 m.aspxhome.com