pytorch中使用cuda扩展的实现示例
作者:outthinker 时间:2021-02-17 23:46:55
以下面这个例子作为教程,实现功能是element-wise add;
(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)
第一步:cuda编程的源文件和头文件
// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"
// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
int k = (n - 1) / BLOCK + 1;
int x = k;
int y = 1;
if(x > 65535) {
x = ceil(sqrt(k));
y = (n - 1) / (x * BLOCK) + 1;
}
dim3 d(x, y, 1);
return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
if(i >= size) return;
int j = i % x; i = i / x;
int k = i % y;
a[IDX2D(j, k, y)] += b[k];
}
// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
int size = x * y;
cudaError_t err;
// 上面定义的函数
broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);
err = cudaGetLastError();
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL
#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))
#define BLOCK 512
#define MAX_STREAMS 512
#ifdef __cplusplus
extern "C" {
#endif
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
第二步:C编程的源文件和头文件(接口函数)
// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"
extern THCState *state;
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
float *a = THCudaTensor_data(state, a_tensor);
float *b = THCudaTensor_data(state, b_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
// 这里调用之前在cuda中编写的接口函数
broadcast_sum_cuda(a, b, x, y, stream);
return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);
第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)
nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension
this_file = os.path.dirname(__file__)
sources = []
headers = []
defines = []
with_cuda = False
if torch.cuda.is_available():
print('Including CUDA code.')
sources += ['src/mathutil_cuda.c']
headers += ['src/mathutil_cuda.h']
defines += [('WITH_CUDA', None)]
with_cuda = True
this_file = os.path.dirname(os.path.realpath(__file__))
extra_objects = ['src/mathutil_cuda_kernel.cu.o'] # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
ffi = create_extension(
'_ext.cuda_util',
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
extra_objects=extra_objects
)
if __name__ == '__main__':
ffi.build()
第四步:调用cuda模块
from _ext import cuda_util #从对应路径中调用编译好的模块
a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))
# 上面等价于下面的效果:
a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b
来源:https://www.cnblogs.com/zf-blog/p/11883166.html
标签:pytorch,cuda
0
投稿
猜你喜欢
python 爬虫如何实现百度翻译
2023-02-20 18:33:51
python3使用python-redis-lock解决并发计算问题
2021-05-09 16:04:18
python名片管理系统开发
2022-06-25 13:43:47
MySQL中replace into语句的用法详解
2024-01-20 10:45:53
Python中optparse模块使用浅析
2023-10-21 06:55:38
python 自动监控最新邮件并读取的操作
2023-02-04 12:58:51
javascript自启动函数的问题探讨
2024-04-30 08:55:57
asp如何随机显示网站链接?
2010-06-07 20:40:00
python执行CMD指令,并获取返回的方法
2021-10-19 02:52:40
asp如何实现无组件上传二进制文件?
2010-06-03 10:09:00
python用socket实现协议TCP长连接框架
2022-05-08 00:22:47
利用Python Matlab绘制曲线图的简单实例
2021-05-16 07:21:38
在SQL查询中使用LIKE来代替IN查询的方法
2011-09-30 11:10:18
Python request post上传文件常见要点
2022-11-05 09:27:14
对vue.js中this.$emit的深入理解
2024-04-26 17:40:12
轻松处理Dreamweaver段落缩进
2007-11-17 07:53:00
HTML 5 正在改变 Web
2008-09-15 08:20:00
从两个方面讲解SQL Server口令的脆弱性
2009-01-08 13:40:00
Linux下编译安装Mysql 5.5的简单步骤
2024-01-27 13:33:03
500行Python代码打造刷脸考勤系统
2022-01-21 12:54:10