LSTM的cuda加速

TensorFlow LSTM benchmark

TensorFlow提供5种LSTM变体:(1)BasicLSTMCell,(2)LSTMCell,(3)LSTMBlockCell,(4)LSTMBlockFusedCell和(5)cuDNNLSTM

测试环境GTX1080Ti:seqLength 100, numLayers 1, hiddenSize 512, miniBatch 64 执行1000次取均值

Cell CPU GPU 分析(CPU and GPU)
BasicLSTMCell 195.74ms 46.89ms 时间步上用tf.dynamic_rnn循环,由于实现简单,速度相对较快
LSTMCell 202.13ms 50.83ms 标准LSTM,有更多的参数选择,速度相差不多
LSTMBlockCell 254.76ms 48.72ms 用于单运行RNN交互场景,时间步上和tf.while_loop结合使用,CPU上速度明显慢,理论在GPU上应该比Basic快,实际速度相差不多
LSTMBlockFusedCell 190.48ms 23.06ms CPU上速度轻微加快,GPU如预期速度快很多
cuDNNLSTM / 18.63ms 在CPU上默认执行LSTMBlockFused,在GPU上是5种变体速度最快的
cudaLSTM / 11.36ms 自己写的源码如下,比tf自带版本都要快

具体流程

定义注册op的接口

即规划输入和输出个数,张量尺寸等,在此之前,因为我们要对TensorFlow库进行修改,需要用源码编译方式重新安装TensorFlow,这部分参考官方例子完成,而定义注册算子接口,需要实现一个c文件,继承Opkernal(TensorFlow对外提供的类接口)类后,重写Compute(Opkernal类下的接口函数)函数,会默认传入OpKernelContext(TensorFlow中的一种定义好的类结构)指针,根据这个指针利用OP_REQUIRES(TensorFlow提供的接口函数)申请张量,输入输出均定义后调用重写的MyLSTM核函数,这个核函数在cu文件中定义,最后使用REGISTER_KERNEL_BUILDER(一种默认注册函数)指定可运行的设备和核函数环境。

为op实现Kernal

即编写cu文件,整个LSTM核算子实现经历了四次迭代,分别为v0,v1,v2,v3四个版本,v2是四层分开的通用版本,v3是特性针对顺滑四层网络结构一体化版本,主要优化有

  1. 每次启用核函数十分耗时,因此将多个block(显卡单元的一种单位)尽可能融合进一个核进行计算,减少<<< >>>(核启动的默认写法,这里代表一次核启动操作)的操作,如把向量加法和向量乘法融合进元素集合操作中,但是要特别注意在核函数中用blockIdx.x(线程块索引) * blockDim.x(线程块维度索引) + threadIdx.x(线程索引)来控制总的索引,进而控制某个门的运算或隐层单元激活等复杂的计算。

  2. 四层网络需要尽可能的并行,LSTM可分为两种流向进行,横向时间步流向使用CUDA流控制,纵向层流向使用CUDA事件进行控制,这两部分流向用同步和等待函数控制,并行的同时又严格遵循计算步骤。v3版内置了每一层的隐层个数并再初始化时直接初始化四层所需要的所有参数,因此是在不改变模型的前提下最快的实现版本。

  3. LSTM有多个矩阵运算,四个门的计算可以进行矩阵拼接。如图5-7中遗忘门的计算ft=(Wixxt+Wimmt-1+bi),Wixxi和Wimmt-1在每个时间步进行循环计算,xi和时间状态无关,即可把每个时间步的输入xi拼接后再和参数矩阵相乘,而对Wimmt-1需要使用for循环实现,这两个矩阵乘法需要合作来计算遗忘门,但相互却并不影响,因此可以并行计算两个矩阵乘,然后得到两个矩阵后再传入元素计算函数,元素计算函数即使用了优化1把激活步骤和加和步骤原本两个启动核变为一个启动核,在整个函数里进行门控单元的计算。

  4. 最后,由于之前进行了大量的矩阵拼接,算子中便包含很多矩阵运算,测试表明时间占比也多数花在这些运算上,因此矩阵优化也是重中之重,矩阵运算可以使用cuBLAS(英伟达开发的矩阵运算库),MKL(Intel Math Kernel Library,英特尔开发的矩阵运算库),Engine(C++自带的开源几何运算库)库,详细速度对比测试结果在第六章系统测试中阐述,最终结论是cuBLAS是综合表现最快的,因此选用cuBLAS进行矩阵运算,这里还有一个可大幅提速的点是cuBLAS默认是列读入,而我们数据是行排列,因此可以使用数学转置的小技巧进行优化,比如要计算C=AB,使用BT*AT计算得到CT,最后再转置会按列大大加速计算。

混合编译cu文件和c文件

注意要在源码编译下的TensorFlow文件夹内编译,即~/TensorFlow/TensorFlow/core/user_ops文件夹下,使用NVCC(类似GCC编译器,英伟达针对CUDA语言的一种编译器脚本,nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /usr/local/cuda/lib64/ -std=c++11)语句把需要的cuBLAS和cuRAND函数库链接,编译出LSTM的CUDA库,然后GCC编译注册文件把核函数链接进来成为总库,或写CMake文件编译,得到最后的CUDA_LSTM_forward.so动态库文件。

python网络调用

使用TensorFlow自带的函数load_op_library找到我们上一步编译好的动态库文件载入,如图5-8所示,使用的v3版本需要传入4层的状态列表,返回4层的状态输出,虽然最终只需要最后一层的结果,但是四层结果全部返回方便进行一致性比对和调试,可在工程代码中再对返回进行优化。

收尾工作

代码部分已完成,需要测试自己实现的算子,并保证计算过程的正确性和计算结果的一致性,即所谓的一致性比对。使用同样的输入,同样的模型参数配置,对比TensorFlow自带的LSTMCell和自己实现的算子MyLSTM的输出结果,从最终结果往上一层一层的比对,发现不一致的错误点,优化迭代,删除冗余代码,最后对各版本进行整理,整合到内核代码。工程项目文件夹如图5-9所示,paper文件夹是所有参考到的论文,benchmark文件夹是测试TensorFlow5个运算算子运算的时间,test_consistent文件夹进行一致性的比对,README整理了各版本的使用方法和使用环境,是一个很完整的顺滑优化项目,总工程大概历时三个月完成,但效果十分显著,见测试章节。

源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,
float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,
bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0){

// cudaErrCheck(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
// float* tmp_i, *tmp_h, *c_o_data + 2 * numElements, *w_i_diag, *w_f_diag, *w_o_diag, *proj_kernel, *c_i_data + 2 * miniBatch * h_depth, *input_T, *input_T + miniBatch * input_depth * seqLength + input_depth * gateSize, *input_T + miniBatch * input_depth * seqLength,
// *h_op_data, *h_op_data + miniBatch * h_depth, *c_data_in_cuda, *c_o_data + numElements, *bias;
float *w_diag_bias_proj;



float alpha = 1.f;
float beta = 0.f;
int input_depth = inputSize;
int gateSize = hiddenSize * 4;
int h_depth;
int numElements = miniBatch * hiddenSize;

if(use_peepholes == true)
h_depth = outSize;
else
h_depth = hiddenSize;

int w_diag_bias_proj_size = (7 + h_depth) * hiddenSize;
int h_op_data_size = 2 * miniBatch * h_depth + miniBatch * seqLength * h_depth;
int input_T_size = miniBatch * seqLength * input_depth + (input_depth + h_depth) * gateSize;
int c_o_data_size = 2 * miniBatch * hiddenSize + miniBatch * seqLength * hiddenSize;
int tmp_i_size = miniBatch * seqLength * gateSize;
int tmp_h_size = miniBatch * gateSize;

printf("the seqLength is: %d, inputSize: %d, input_depth: %d, hiddenSize: %d, outSize: %d\n", seqLength, inputSize, input_depth, hiddenSize, outSize);

cudaErrCheck(cudaGetLastError());

//six to one
// cudaErrCheck(cudaMalloc((void**)&tmp_i, miniBatch * seqLength * gateSize * sizeof(float)));
// cudaErrCheck(cudaMalloc((void**)&tmp_h, miniBatch * gateSize * sizeof(float)));
// cudaErrCheck(cudaMalloc((void**)&input_T, (miniBatch * input_depth * seqLength + input_depth * gateSize + h_depth * gateSize ) * sizeof(float)));
// cudaErrCheck(cudaMalloc((void**)&h_op_data, miniBatch * (seqLength+2) * h_depth * sizeof(float)));
// cudaErrCheck(cudaMalloc((void**)&c_o_data, miniBatch * (seqLength+2) * hiddenSize * sizeof(float)));
cudaErrCheck(cudaMalloc((void**)&w_diag_bias_proj, (w_diag_bias_proj_size + input_T_size + h_op_data_size + c_o_data_size + tmp_i_size + tmp_h_size) * sizeof(float)));

//b = a + size(a);
float *input_T = w_diag_bias_proj + w_diag_bias_proj_size;
//c = b + size(b);
float *h_op_data = input_T + input_T_size;
float *c_o_data = h_op_data + h_op_data_size;
float *tmp_i = c_o_data + c_o_data_size;
float *tmp_h = tmp_i + tmp_i_size;

cudaStream_t stream_i, stream_h;
cudaErrCheck(cudaStreamCreate(&stream_i));
cudaErrCheck(cudaStreamCreate(&stream_h));
bool stream_i_flag = true;

//pivot
// cudaErrCheck(cudaMemcpy(input_T, input, miniBatch * input_depth * seqLength * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(input_T + miniBatch * input_depth * seqLength, weight_i, input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(input_T + miniBatch * input_depth * seqLength + input_depth * gateSize, weight_h, h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice));
// // printf("*************************%d\n", hiddenSize);
// cudaErrCheck(cudaMemcpy(h_op_data, h_data_in, h_depth * miniBatch * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(c_o_data, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice));


// // printf("i_data up and i_data_beforeProj down and the seqLength is%d\n", seqLength);

// cudaErrCheck(cudaMemcpy(w_diag_bias_proj, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(w_diag_bias_proj + hiddenSize, w_f_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(w_diag_bias_proj + 2 * hiddenSize, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(w_diag_bias_proj + 3 * hiddenSize , bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice));
// cudaErrCheck(cudaMemcpy(w_diag_bias_proj + 7 * hiddenSize, proj_kernel_in, h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice));


cudaErrCheck(cudaMemcpyAsync(input_T, input, miniBatch * input_depth * seqLength * sizeof(float), cudaMemcpyHostToDevice, stream_i));
cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength, weight_i, input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_i));
cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength + input_depth * gateSize, weight_h, h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
// printf("*************************%d\n", hiddenSize);
cudaErrCheck(cudaMemcpyAsync(h_op_data, h_data_in, h_depth * miniBatch * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(c_o_data, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));


// printf("i_data up and i_data_beforeProj down and the seqLength is%d\n", seqLength);

cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + hiddenSize, w_f_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 2 * hiddenSize, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 3 * hiddenSize , bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 7 * hiddenSize, proj_kernel_in, h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));




cudaErrCheck(cudaGetLastError());

// cudaDeviceSynchronize();
// Need a cuBLAS handle.
cublasHandle_t handle;
cublasErrCheck(cublasCreate(&handle));


cublasErrCheck(cublasSetStream(handle, stream_i));
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
gateSize, miniBatch * seqLength, input_depth,
&alpha,
input_T + miniBatch * input_depth * seqLength,
gateSize,
input_T,
input_depth,
&beta,
tmp_i,
gateSize));



cudaErrCheck(cudaGetLastError());

// cudaEvent_t event1, event2;
// cudaEventCreate(&event1);
// cudaEventCreate(&event2);

// printf("######tmp_i: \n");
// float* tmp_i_cpu = init_Matrix_zeros(miniBatch * gateSize);
// cudaMemcpy(tmp_i_cpu, tmp_i, miniBatch * gateSize * sizeof(float), cudaMemcpyDeviceToHost);
// cudaDeviceSynchronize();
// printMatix(tmp_i_cpu, miniBatch * gateSize);

// cudaEventCreateWithFlags(&event1, cudaEventBlockingSync);
// cudaEventCreateWithFlags(&event2, cudaEventBlockingSync);


for(int i = 0; i < seqLength; ++i){
// cudaEventRecord(event1, 0);
cublasErrCheck(cublasSetStream(handle, stream_h));
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
gateSize, miniBatch, h_depth,
&alpha,
input_T + miniBatch * input_depth * seqLength + input_depth * gateSize,
gateSize,
h_op_data,
h_depth ,
&beta,
tmp_h,
gateSize));

dim3 blockDim;
dim3 gridDim;

blockDim.x = 256;
gridDim.x = (miniBatch * hiddenSize + blockDim.x - 1) / blockDim.x;

if(stream_i_flag == true)
cudaErrCheck(cudaStreamSynchronize(stream_i));
elementWise_fp <<< gridDim, blockDim, 0 >>>
(hiddenSize, miniBatch,
tmp_h,
tmp_i + i * miniBatch * gateSize,
w_diag_bias_proj + 3 * hiddenSize,
NULL,
h_op_data + miniBatch * h_depth,
c_o_data + 2 * numElements + i * miniBatch * hiddenSize,
c_o_data,
c_o_data + numElements,
false,
w_diag_bias_proj,
w_diag_bias_proj + hiddenSize,
w_diag_bias_proj + 2 * hiddenSize,
use_peepholes,
h_depth,
cell_clip);

if(stream_i_flag == true){
cudaErrCheck(cudaStreamDestroy(stream_i));
stream_i_flag = false;
}

cudaErrCheck(cudaGetLastError());
if(use_peepholes != 0){
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
h_depth, miniBatch, hiddenSize,
&alpha,
w_diag_bias_proj + 7 * hiddenSize,
h_depth,
c_o_data + 2 * numElements + i * miniBatch * hiddenSize,
hiddenSize,
&beta,
h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth,
h_depth));
if(proj_clip != 0){
// printf("in proj_clip\n");
dim3 blockDim;
dim3 gridDim;
blockDim.x = 256;
gridDim.x = (h_depth * miniBatch + blockDim.x - 1) / blockDim.x;
clip_by_value <<< gridDim, blockDim, 0 >>>
(h_op_data + 2 * miniBatch * h_depth + i * h_depth * miniBatch, proj_clip, miniBatch * h_depth);
}
//h_data和i_data保持同步
}
// cudaEventRecord(event2, 0);

// cudaEventSynchronize(event1);
// cudaEventSynchronize(event2);
// cudaDeviceSynchronize();

cudaErrCheck(cudaMemcpy(h_op_data + miniBatch * h_depth, h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaMemcpy(h_op_data, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaMemcpy(c_o_data, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaGetLastError());


// cudaErrCheck(cudaMemcpy(out_cpu1, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));
// cudaDeviceSynchronize();
// printMatix(out_cpu1, miniBatch * h_depth);

// float *c = (float*)malloc(sizeof(float)*miniBatch * h_depth);
// float* out_cpu = new float[miniBatch * h_depth];
// cudaErrCheck(cudaMemcpy(c, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));
// printMatix(c, miniBatch * h_depth);
// printf("end for*************************%d\n", hiddenSize);
cudaErrCheck(cudaGetLastError());

}
// cudaDeviceSynchronize();
// printf("end 1 *************************%d\n", hiddenSize);

// float* out_cpu = new float[seqLength * miniBatch * h_depth];
// cudaErrCheck(cudaMemcpy(out_cpu, c_i_data + 2 * miniBatch * h_depth, seqLength * miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));
// cudaDeviceSynchronize();
// printMatix(out_cpu, seqLength * miniBatch * h_depth);

cudaErrCheck(cudaMemcpy(h_data_out, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));
cudaErrCheck(cudaMemcpy(c_data_out, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToHost));
cudaErrCheck(cudaMemcpy(output, h_op_data + 2 * miniBatch * h_depth, seqLength * miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));


// printf("free *************************%d\n", hiddenSize);
//six to one
// cudaErrCheck(cudaFree(tmp_i));
// cudaErrCheck(cudaFree(tmp_h));
// cudaErrCheck(cudaFree(c_o_data));
// cudaErrCheck(cudaFree(input_T));
// cudaErrCheck(cudaFree(h_op_data));
cudaErrCheck(cudaFree(w_diag_bias_proj));

cudaErrCheck(cudaStreamDestroy(stream_h));


// printf("end 2 *************************%d\n", hiddenSize);
}

优化前后各计算模块速度对比

优化前

优化后

GTX1080Ti显卡下使用TFS CudaLstm测速对比,四层LSTM每层的网络结构:

seqLength: 4, inputSize: 320, hiddenSize: 1536, outSize: 320

seqLength : 4, inputSize: 320, hiddenSize: 1536, outSize: 320

seqLength : 4, inputSize: 320, hiddenSize: 1536, outSize: 448

seqLength : 4, inputSize: 448, hiddenSize: 1536, outSize: 448

V1版测试结果
并发数 1 2 4 8 16 32
LSTM 25.94ms 24.50ms 24.20ms 26.25ms 26.44ms 29.22ms
MyLSTM 21.20ms 20.70ms 21 .04ms 23.01ms 23.88ms 27.32ms
V2版测试结果
并发数 1 2 4 8 16 32 64 128
LSTM 25.21 ms 24.96ms 25.39ms 27.26ms 28.18ms 33.08ms 41.25ms 78.67ms
MyLSTM 23.47ms 23.49ms 23.37ms 23.89ms 24.7 1ms 30.58ms 38.09ms 76.22ms
V3版测试结果
并发数 1 2 4 8 16 32
LSTM 25.73ms 24.73ms 24. .80s 26.27ms 26.65ms 29.70ms
MyLSTM 19.62ms 19.82ms 19.94s 21.38ms 21.24ms 24.34ms

​ V1版未对异步计算进行优化,且内存是在kernal函数中分配,V2版对异步计算进行优化,内存调用TF接口分配在cc注册函数中,传入LSTMkernal就不再分配内存,V1V2均是通用版LSTM,接口保持了和TF.LSTM一致的接口,V3版针对我们网络模型特定优化,把四层结构的内存和计算放入一个函数中,即四层从时间步和层数上通过stream和event进行控制,cc注册接口暴露出输入,返回的输出直接是四层计算后的结果,模型固定参数见测试参数,因此V3版可以理解为四层一块计算,速度更快,但是后三层的隐层由op固定,可以在op kernal内修改,但不能在接口处修改。