pytorch中使用cuda扩展的实现示例

作者:outthinker 时间:2021-02-17 23:46:55 

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件


// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
 int k = (n - 1) / BLOCK + 1;
 int x = k;
 int y = 1;
 if(x > 65535) {
   x = ceil(sqrt(k));
   y = (n - 1) / (x * BLOCK) + 1;
 }
 dim3 d(x, y, 1);
 return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
 int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
 if(i >= size) return;
 int j = i % x; i = i / x;
 int k = i % y;
 a[IDX2D(j, k, y)] += b[k];
}

// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
 int size = x * y;
 cudaError_t err;

// 上面定义的函数
 broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

err = cudaGetLastError();
 if (cudaSuccess != err)
 {
   fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
   exit(-1);
 }
}

#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)


// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
 float *a = THCudaTensor_data(state, a_tensor);
 float *b = THCudaTensor_data(state, b_tensor);
 cudaStream_t stream = THCState_getCurrentStream(state);

// 这里调用之前在cuda中编写的接口函数
 broadcast_sum_cuda(a, b, x, y, stream);

return 1;
}


int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)


nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52

import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
 print('Including CUDA code.')
 sources += ['src/mathutil_cuda.c']
 headers += ['src/mathutil_cuda.h']
 defines += [('WITH_CUDA', None)]
 with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
 '_ext.cuda_util',
 headers=headers,
 sources=sources,
 define_macros=defines,
 relative_to=__file__,
 with_cuda=with_cuda,
 extra_objects=extra_objects
)

if __name__ == '__main__':
 ffi.build()

第四步:调用cuda模块


from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

来源:https://www.cnblogs.com/zf-blog/p/11883166.html

标签:pytorch,cuda
0
投稿

猜你喜欢

  • python使用正则表达式(Regular Expression)方法超详细

    2022-09-07 14:30:53
  • Python中struct模块对字节流/二进制流的操作教程

    2021-05-10 19:26:56
  • python中protobuf和json互相转换应用处理方法

    2023-03-15 11:50:55
  • Python 有可能删除 GIL 吗?

    2023-02-12 15:52:21
  • javascript拼音搜索引擎

    2011-08-29 15:42:14
  • Python 中random 库的详细使用

    2022-01-19 05:35:14
  • Python使用tkinter库实现文本显示用户输入功能示例

    2023-09-21 11:00:17
  • 用css3-tranistions实现平滑过渡

    2009-12-23 19:24:00
  • django 在原有表格添加或删除字段的实例

    2023-11-25 04:21:08
  • python基础面试题整理

    2023-11-03 02:09:45
  • python基本语法练习实例

    2021-02-25 06:50:07
  • python 实现网易邮箱邮件阅读和删除的辅助小脚本

    2022-03-17 10:55:34
  • asp用正则解析远程图片地址,用XMLHTTP将其保存

    2007-10-26 12:34:00
  • Python pandas对excel的操作实现示例

    2023-09-25 18:24:13
  • 关于Python可视化Dash工具之plotly基本图形示例详解

    2023-08-13 15:51:57
  • 对python中的pop函数和append函数详解

    2021-10-09 09:11:32
  • 如何实现My SQL中的用户的管理问题

    2008-12-03 13:56:00
  • 朋友去一家游戏公司的机试题,被难住了

    2009-11-29 15:23:00
  • Response.Flush的使用心得

    2010-04-08 12:57:00
  • pycharm 使用anaconda为默认环境的操作

    2023-10-08 12:37:25
  • asp之家 网络编程 m.aspxhome.com