Registering Your Own Op in TensorFlow

The goal is to speed up the LSTM portion of a cloud-side algorithm: compile a .cu file into a .so, and use the LSTM class/function in that .so to replace tf.LSTMCell for the computation.
The full project is on GitHub and the overall workflow is described in this blog. I am also new to CUDA, so comments and discussion are welcome.

1. Building TensorFlow from source

Because we need to modify the TF library, TensorFlow has to be reinstalled by building it from source. The official instructions cover this clearly, so I won't rehash them here.

2. Op registration workflow:

  1. Define the Op's interface, i.e. write the .cc file following the required conventions.

  2. Implement a kernel for the Op, i.e. your own .cu file.

  3. Compile it into a .so, i.e. the BUILD.sh script. All three files are shown below; as before, read the official guide first and the example will make much more sense (a minimal skeleton of the .cc side is sketched right after this list).
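Before the full example, here is a minimal, self-contained sketch of steps 1 and 2 for a toy op, modeled on the official ZeroOut example from the TensorFlow guide; none of the names below belong to this project:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Step 1: declare the Op's interface (inputs, outputs, shape function).
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));   // output has the same shape as the input
      return Status::OK();
    });

// Step 2: implement a kernel for the Op.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    auto input = input_tensor.flat<int32>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input_tensor.shape(), &output_tensor));
    auto output = output_tensor->flat<int32>();

    // Keep the first element, zero out the rest.
    const int N = input.size();
    for (int i = 1; i < N; i++) output(i) = 0;
    if (N > 0) output(0) = input(0);
  }
};

// Register the kernel for a device; step 3 then compiles this file into a .so.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

The real op below follows exactly this structure, only with ten inputs, three outputs, and a CUDA entry point doing the actual work.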

3. Example

  1. Prepare the .cc file as shown in the figure above (cuda_lstm_forward.cc):

#include <stdio.h>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/framework/allocator.h"
#include "fsmn_forward.h"
#include <cstddef>
#include <iostream>
#include <algorithm>
using namespace tensorflow;
// Entry point implemented in 00_lstm.cu; declared with C linkage to match its definition there.
// Note the argument order: the cell state c comes before the output state h.
extern "C" void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float* bias_data_in, float* w_i_diag_in,
float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,
bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0);


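// Interface of the custom op: ten inputs (the flattened input sequence, the
// previous cell state c and output state h, the input and recurrent kernels,
// the bias, the three peephole diagonals, and the projection kernel) and three
// outputs (the new c, the new h, and the projected output for every timestep).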
REGISTER_OP("CudaLstmForward")
.Input("input: float32")
.Input("cdata_in: float32")
.Input("hdata_in: float32")
.Input("weight_i: float32")
.Input("weight_h: float32")
.Input("bias: float32")
.Input("w_i_diag: float32")
.Input("w_f_diag: float32")
.Input("w_o_diag: float32")
.Input("proj_kernel: float32")
.Output("cdata_out: float32")
.Output("hdata_out: float32")
.Output("output: float32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
c->set_output(0, c->Matrix(c->Dim(c->input(1), 0), c->Dim(c->input(1), 1)));
c->set_output(1, c->Matrix(c->Dim(c->input(1), 0), c->Dim(c->input(2), 1)));
c->set_output(2, c->Matrix(c->Dim(c->input(0), 0), c->Dim(c->input(2), 1)));
return Status::OK();
});


class CudaLstmForwardOp : public OpKernel {
public:
explicit CudaLstmForwardOp(OpKernelConstruction* ctx) : OpKernel(ctx){
}
void Compute(OpKernelContext* ctx) override {
//printf("begin....");

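// mutable_input() is used so that .tensor<float, 2>().data() below yields
// non-const float* pointers, matching the parameters of LSTMTest().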
Tensor input_data= ctx->mutable_input(0, true);
Tensor cdata_in= ctx->mutable_input(1, true);
Tensor hdata_in= ctx->mutable_input(2, true);

//printf("input....");
OP_REQUIRES(ctx, input_data.shape().dims() == 2,
errors::InvalidArgument("input data is not a 2-Tensor"));
OP_REQUIRES(ctx, hdata_in.shape().dims() == 2,
errors::InvalidArgument("hdata in is not a 2-Tensor"));
OP_REQUIRES(ctx, cdata_in.shape().dims() == 2,
errors::InvalidArgument("cdata in is not a 2-Tensor"));
//printf("weight....");
Tensor weight_i= ctx->mutable_input(3, true);
Tensor weight_h= ctx->mutable_input(4, true);
Tensor bias= ctx->mutable_input(5, true);
Tensor wi_diag= ctx->mutable_input(6, true);
Tensor wf_diag= ctx->mutable_input(7, true);
Tensor wo_diag= ctx->mutable_input(8, true);
Tensor proj= ctx->mutable_input(9, true);

auto inputdata_t = input_data.tensor<float, 2>();
auto cdatain_t = cdata_in.tensor<float, 2>();
auto hdatain_t = hdata_in.tensor<float, 2>();

auto weighti_t = weight_i.tensor<float, 2>();
auto weighth_t = weight_h.tensor<float, 2>();
auto bias_t = bias.tensor<float, 1>();

auto widiag_t = wi_diag.tensor<float, 1>();
auto wfdiag_t = wf_diag.tensor<float, 1>();
auto wodiag_t = wo_diag.tensor<float, 1>();
auto proj_t = proj.tensor<float, 2>();

const auto &acti_shape = input_data.shape();
int seq_batch = acti_shape.dim_size(0);
int inputsize = acti_shape.dim_size(1);

const auto &acth_shape = cdata_in.shape();
int batch = acth_shape.dim_size(0);
int hiddensize = acth_shape.dim_size(1);

const auto &actc_shape = hdata_in.shape();
int outputsize = actc_shape.dim_size(1);

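// The input sequence arrives flattened as [batch * seqLength, inputSize], so
// the sequence length is recovered by dividing its row count by the batch size.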
int length = seq_batch/batch;

// Create an state out tensor
Tensor *state_outc = nullptr;
TensorShape indice_shape({batch, hiddensize});
OP_REQUIRES_OK(ctx, ctx->allocate_output("cdata_out", indice_shape, &state_outc));
auto statec_t = state_outc->tensor<float, 2>();

// Create an state out tensor
Tensor *state_outh = nullptr;
TensorShape indice_shape1({batch, outputsize});
OP_REQUIRES_OK(ctx, ctx->allocate_output("hdata_out", indice_shape1, &state_outh));
auto stateh_t = state_outh->tensor<float, 2>();

// Create an output tensor
Tensor *out_put = nullptr;
TensorShape indice_shape2({seq_batch, outputsize});
OP_REQUIRES_OK(ctx, ctx->allocate_output("output", indice_shape2, &out_put));
auto out_t = out_put->tensor<float, 2>();


// Launch the CUDA LSTM forward computation
LSTMTest(batch, length, inputsize, hiddensize, outputsize,
inputdata_t.data(), cdatain_t.data(), hdatain_t.data(),
weighti_t.data(), weighth_t.data(), bias_t.data(),
widiag_t.data(), wfdiag_t.data(), wodiag_t.data(), proj_t.data(),
statec_t.data(), stateh_t.data(), out_t.data(),
true, 0.0, 50.0);


}

private:

}; //class CudaLstmForward end
REGISTER_KERNEL_BUILDER(Name("CudaLstmForward").Device(::tensorflow::DEVICE_CPU), CudaLstmForwardOp);
REGISTER_KERNEL_BUILDER(Name("CudaLstmForward").Device(DEVICE_GPU), CudaLstmForwardOp);

  2. The CUDA kernel implementation, 00_lstm.cu:

extern "C" void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,
float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,
bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0){


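// Scratch-space bookkeeping for four LSTM layers (sizes hard-coded for
// hiddenSize = 1536 with projection sizes 320 and 448). One device buffer
// covering all four layers is allocated on the first call; `flag` counts calls
// so that each layer uses its own slice, and the buffer is freed after the 4th call.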
static int layer0_size = (7 + 320 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 320 + (miniBatch * 4) * 320 + (320 + 320 + miniBatch * 4 + miniBatch) * 4 * 1536;
static int layer1_size = (7 + 320 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 320 + (miniBatch * 4) * 320 + (320 + 320 + miniBatch * 4 + miniBatch) * 4 * 1536;
static int layer2_size = (7 + 448 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 448 + (miniBatch * 4) * 320 + (320 + 448 + miniBatch * 4 + miniBatch) * 4 * 1536;
static int layer3_size = (7 + 448 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 448 + (miniBatch * 4) * 448 + (448 + 448 + miniBatch * 4 + miniBatch) * 4 * 1536;
static float *w_diag_bias_proj;
static float *init_pointer;
static int all_layer_size = layer0_size + layer1_size + layer2_size + layer3_size;
static int flag = 0;



float alpha = 1.f;
float beta = 0.f;
int input_depth = inputSize;
int gateSize = hiddenSize * 4;
int h_depth;
int numElements = miniBatch * hiddenSize;

if(use_peepholes == true)
h_depth = outSize;
else
h_depth = hiddenSize;

int w_diag_bias_proj_size = (7 + h_depth) * hiddenSize;
int h_op_data_size = 2 * miniBatch * h_depth + miniBatch * seqLength * h_depth;
int input_T_size = miniBatch * seqLength * input_depth + (input_depth + h_depth) * gateSize;
int c_o_data_size = 2 * miniBatch * hiddenSize + miniBatch * seqLength * hiddenSize;
int tmp_i_size = miniBatch * seqLength * gateSize;
// int tmp_h_size = miniBatch * gateSize;

// printf("the seqLength is: %d, inputSize: %d, input_depth: %d, hiddenSize: %d, outSize: %d\n", seqLength, inputSize, input_depth, hiddenSize, outSize);

cudaErrCheck(cudaGetLastError());

if(flag == 0){
cudaErrCheck(cudaMalloc((void**)&w_diag_bias_proj, all_layer_size * sizeof(float)));
init_pointer = w_diag_bias_proj;
}

if(flag == 1)
w_diag_bias_proj = w_diag_bias_proj + layer0_size;
else if(flag == 2)
w_diag_bias_proj = w_diag_bias_proj + layer1_size;
else if(flag == 3)
w_diag_bias_proj = w_diag_bias_proj + layer2_size;
flag++;

printf("flag: %d\n", flag);

// Partition the single device allocation: each buffer starts where the
// previous one ends (b = a + size(a), c = b + size(b), ...).
float *input_T = w_diag_bias_proj + w_diag_bias_proj_size;
float *h_op_data = input_T + input_T_size;
float *c_o_data = h_op_data + h_op_data_size;
float *tmp_i = c_o_data + c_o_data_size;
float *tmp_h = tmp_i + tmp_i_size;

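// Two CUDA streams: stream_i uploads the inputs and the input kernel and runs
// the input-projection GEMM for the whole sequence, while stream_h carries the
// per-timestep recurrent GEMM, the element-wise gate kernel and the projection.
// stream_i is synchronized once, before the first element-wise kernel, and then destroyed.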
cudaStream_t stream_i, stream_h;
cudaErrCheck(cudaStreamCreate(&stream_i));
cudaErrCheck(cudaStreamCreate(&stream_h));
bool stream_i_flag = true;



cudaErrCheck(cudaMemcpyAsync(input_T, input, miniBatch * input_depth * seqLength * sizeof(float), cudaMemcpyHostToDevice, stream_i));
cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength, weight_i, input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_i));
cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength + input_depth * gateSize, weight_h, h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
// printf("*************************%d\n", hiddenSize);
cudaErrCheck(cudaMemcpyAsync(h_op_data, h_data_in, h_depth * miniBatch * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(c_o_data, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));


// printf("i_data up and i_data_beforeProj down and the seqLength is%d\n", seqLength);

cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + hiddenSize, w_f_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 2 * hiddenSize, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 3 * hiddenSize , bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 7 * hiddenSize, proj_kernel_in, h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));




cudaErrCheck(cudaGetLastError());

// cudaDeviceSynchronize();
// Need a cuBLAS handle.
cublasHandle_t handle;
cublasErrCheck(cublasCreate(&handle));


cublasErrCheck(cublasSetStream(handle, stream_i));
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
gateSize, miniBatch * seqLength, input_depth,
&alpha,
input_T + miniBatch * input_depth * seqLength,
gateSize,
input_T,
input_depth,
&beta,
tmp_i,
gateSize));



cudaErrCheck(cudaGetLastError());


for(int i = 0; i < seqLength; ++i){
// cudaEventRecord(event1, 0);
cublasErrCheck(cublasSetStream(handle, stream_h));
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
gateSize, miniBatch, h_depth,
&alpha,
input_T + miniBatch * input_depth * seqLength + input_depth * gateSize,
gateSize,
h_op_data,
h_depth ,
&beta,
tmp_h,
gateSize));

dim3 blockDim;
dim3 gridDim;

blockDim.x = 256;
gridDim.x = (miniBatch * hiddenSize + blockDim.x - 1) / blockDim.x;

if(stream_i_flag == true)
cudaErrCheck(cudaStreamSynchronize(stream_i));
elementWise_fp <<< gridDim, blockDim, 0 , stream_h >>>
(hiddenSize, miniBatch,
tmp_h,
tmp_i + i * miniBatch * gateSize,
w_diag_bias_proj + 3 * hiddenSize,
NULL,
h_op_data + miniBatch * h_depth,
c_o_data + 2 * numElements + i * miniBatch * hiddenSize,
c_o_data,
c_o_data + numElements,
false,
w_diag_bias_proj,
w_diag_bias_proj + hiddenSize,
w_diag_bias_proj + 2 * hiddenSize,
use_peepholes,
h_depth,
cell_clip);

if(stream_i_flag == true){
cudaErrCheck(cudaStreamDestroy(stream_i));
stream_i_flag = false;
}

cudaErrCheck(cudaGetLastError());
if(use_peepholes != 0){
cublasErrCheck(cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
h_depth, miniBatch, hiddenSize,
&alpha,
w_diag_bias_proj + 7 * hiddenSize,
h_depth,
c_o_data + 2 * numElements + i * miniBatch * hiddenSize,
hiddenSize,
&beta,
h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth,
h_depth));
if(proj_clip != 0){
// printf("in proj_clip\n");
dim3 blockDim;
dim3 gridDim;
blockDim.x = 256;
gridDim.x = (h_depth * miniBatch + blockDim.x - 1) / blockDim.x;
clip_by_value <<< gridDim, blockDim, 0, stream_h >>>
(h_op_data + 2 * miniBatch * h_depth + i * h_depth * miniBatch, proj_clip, miniBatch * h_depth);
}
//keep h_data and i_data in sync
}


cudaErrCheck(cudaMemcpy(h_op_data + miniBatch * h_depth, h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaMemcpy(h_op_data, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaMemcpy(c_o_data, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToDevice));
cudaErrCheck(cudaGetLastError());


cudaErrCheck(cudaGetLastError());

}



cudaErrCheck(cudaMemcpy(h_data_out, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));
cudaErrCheck(cudaMemcpy(c_data_out, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToHost));
cudaErrCheck(cudaMemcpy(output, h_op_data + 2 * miniBatch * h_depth, seqLength * miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));

if(flag == 4){
cudaErrCheck(cudaFree(init_pointer));
flag = 0;
}

cudaErrCheck(cudaStreamDestroy(stream_h));

}


int main(int argc, char* argv[]) {
int seqLength;
int numLayers;
int hiddenSize;
int miniBatch;
bool use_peepholes = true;   // defaults; the command-line branch below does not set these
int num_proj = 320;
float cell_clip = 0.0;
float proj_clip = 0.0;

if (argc == 5) {
seqLength = atoi(argv[1]);
numLayers = atoi(argv[2]);
hiddenSize = atoi(argv[3]);
miniBatch = atoi(argv[4]);
}
else if (argc == 1) {
printf("Running with default settings\n");
seqLength = 2;
numLayers = 1;
hiddenSize = 1536;
miniBatch = 100;
use_peepholes = true;
num_proj = 320;
cell_clip = 0.00;
proj_clip = 50.00;
}
else {
printf("Usage: ./LSTM <seqLength> <numLayers> <hiddenSize> <miniBatch>\n");
return 1;
}

printf("seqLength %d, numLayers %d, num_proj %d, miniBatch %d\n", seqLength, numLayers, num_proj, miniBatch);

int outSize = num_proj;
int numRuns = 4;
float totalTime = 0.f;
int input_depth = 320;
float* input = init_Matrix(miniBatch * input_depth * seqLength);
float* h_data_in = init_Matrix_zeros(miniBatch * num_proj);
float* c_data_in = init_Matrix_zeros(miniBatch * hiddenSize);
float* weight_i = init_Matrix(input_depth* hiddenSize * 4);
float* weight_h = init_Matrix(num_proj* hiddenSize * 4);
float* bias_data_in = init_Matrix_zeros(hiddenSize * 4);
float* w_i_diag_in = init_Matrix(hiddenSize);
float* w_f_diag_in = init_Matrix(hiddenSize);
float* w_o_diag_in = init_Matrix(hiddenSize);
float* proj_kernel_in = init_Matrix(hiddenSize * num_proj);
float* h_data_out = init_Matrix_zeros(miniBatch * num_proj);
float* c_data_out = init_Matrix_zeros(miniBatch * hiddenSize);
float* output = init_Matrix_zeros(miniBatch * seqLength * num_proj);

for (int run = 0; run < numRuns; run++) {

LSTMTest(miniBatch, seqLength, input_depth, hiddenSize, outSize,
input, c_data_in, h_data_in, weight_i, weight_h, bias_data_in, w_i_diag_in,
w_f_diag_in, w_o_diag_in, proj_kernel_in, c_data_out, h_data_out, output, use_peepholes, cell_clip, proj_clip);
}

printf("Runtime %fms\n", totalTime / numRuns);

return time < 0;
}
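The listing above calls several helpers that are not defined in it: the device kernels elementWise_fp and clip_by_value, the host initializers init_Matrix / init_Matrix_zeros, and the error-checking macros cudaErrCheck / cublasErrCheck. For orientation only, here is a sketch of the generic pieces; the names come from the call sites above, but the bodies are my assumptions, and the element-wise LSTM gate kernel itself is omitted:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Print-and-continue checking for CUDA runtime calls (assumed definition).
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
inline void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess)
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
}

// Same idea for cuBLAS status codes (assumed definition).
#define cublasErrCheck(stat) { cublasErrCheck_((stat), __FILE__, __LINE__); }
inline void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
    if (stat != CUBLAS_STATUS_SUCCESS)
        fprintf(stderr, "cuBLAS Error: %d %s %d\n", (int)stat, file, line);
}

// Host-side buffers used by main(); assumed to fill with random values / zeros.
inline float* init_Matrix(int n) {
    float *p = (float*)malloc(n * sizeof(float));
    for (int i = 0; i < n; ++i) p[i] = (float)rand() / RAND_MAX;
    return p;
}
inline float* init_Matrix_zeros(int n) {
    return (float*)calloc(n, sizeof(float));
}

// clip_by_value is presumably a symmetric clamp to [-clip, clip], mirroring
// tf.clip_by_value(..., -proj_clip, proj_clip) inside tf.LSTMCell.
__global__ void clip_by_value(float *data, float clip, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] = fminf(fmaxf(data[i], -clip), clip);
}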
  3. Mixed compilation of .c/.cpp and .cu files (building the .so)

  • That is, the .cpp code calls into the .cu code: compile the CUDA part into a library first, then link that library in when building the .cpp (a sketch of such a caller appears after the Makefile below).

Separate compilation: g++ -o test 00_lstm.o 01_cpptest_cuda_lstm.o -lcudart -L/usr/local/cuda/lib64 -lcublas -lcurand -L/home/resources/yxwang/cuda-10.0/lib64/

Static library: nvcc -lib 00_lstm.cu -o lib00_lstm.a

g++ -o test 01_cpptest_cuda_lstm.o lib00_lstm.a -L/usr/local/cuda/lib64 -lcudart -lcublas -lcurand

Dynamic library (BUILD.sh):

# Note: build inside the source-built tensorflow tree; pwd = ~/tensorflow/tensorflow/core/user_ops
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') );TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )

#g++ -std=c++11 -shared -o cuda_lstm_forward.so -c cuda_lstm_forward.cc -ltestcu -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -ltensorflow_framework -L /home/resources/yxwang/cuda-10.0/lib64/ -lcublas -lcurand

## Compile 00_lstm.cu into lib00_lstm.so
#nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /home/resources/yxwang/cuda-10.0/lib64/ -std=c++11
#
## Compile cuda_lstm_forward.cc into cuda_lstm_forward.so, linking it with -l00_lstm -L.
#g++ -std=c++11 -shared cuda_lstm_forward.cc -o cuda_lstm_forward.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -ltensorflow_framework -l00_lstm -L. -lcublas -lcurand -L/home/resources/yxwang/cuda-10.0/lib64/

nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /usr/local/cuda/lib64/ -std=c++11

g++ -std=c++11 -shared cuda_lstm_forward.cc -o cuda_lstm_forward.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -ltensorflow_framework -l00_lstm -L. -lcublas -lcurand -L/usr/local/cuda/lib64/

# -l<name of the library to link>  -L<path to the library>
# Standalone compile test: nvcc -g -G 00_lstm.cu -o 00_lstm -arch=sm_52 -DPERFOPTS=31 -lcublas -lcurand

Or write a Makefile; everything after a ':' is a prerequisite, and you read the rules from the bottom up:

all : cpp

cpp : lib00_lstm.so
g++ 01_cpptest_cuda_lstm.cpp -o 01_cpptest_cuda_lstm /home/resources/yxwang/cuda-10.0/lib64/libcublas.so -l00_lstm -L.

lib00_lstm.so : 00_lstm.cu
nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /home/resources/yxwang/cuda-10.0/lib64/
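The test driver 01_cpptest_cuda_lstm.cpp referenced in these commands is not shown in this post. A minimal caller only needs the extern "C" declaration and the link against lib00_lstm.so; the sketch below is hypothetical, with sizes chosen to match the scratch-buffer constants hard-coded in 00_lstm.cu:

// 01_cpptest_cuda_lstm.cpp (hypothetical sketch): calls into lib00_lstm.so
#include <vector>

// Must match the extern "C" definition in 00_lstm.cu.
extern "C" void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
    float* input, float* c_data_in, float* h_data_in, float* weight_i, float* weight_h, float* bias_data_in,
    float* w_i_diag_in, float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in,
    float* c_data_out, float* h_data_out, float* output,
    bool use_peepholes, float cell_clip, float proj_clip);

int main() {
    const int miniBatch = 4, seqLength = 2, inputSize = 320, hiddenSize = 1536, outSize = 320;
    std::vector<float> input(miniBatch * seqLength * inputSize, 0.1f);
    std::vector<float> c_in(miniBatch * hiddenSize, 0.f), h_in(miniBatch * outSize, 0.f);
    std::vector<float> weight_i(inputSize * 4 * hiddenSize, 0.01f), weight_h(outSize * 4 * hiddenSize, 0.01f);
    std::vector<float> bias(4 * hiddenSize, 0.f);
    std::vector<float> wi(hiddenSize, 0.f), wf(hiddenSize, 0.f), wo(hiddenSize, 0.f);
    std::vector<float> proj(hiddenSize * outSize, 0.01f);
    std::vector<float> c_out(miniBatch * hiddenSize), h_out(miniBatch * outSize);
    std::vector<float> output(miniBatch * seqLength * outSize);

    // 00_lstm.cu manages one device scratch buffer across four consecutive
    // calls (one per layer), so the sketch calls it four times.
    for (int run = 0; run < 4; ++run) {
        LSTMTest(miniBatch, seqLength, inputSize, hiddenSize, outSize,
                 input.data(), c_in.data(), h_in.data(), weight_i.data(), weight_h.data(), bias.data(),
                 wi.data(), wf.data(), wo.data(), proj.data(),
                 c_out.data(), h_out.data(), output.data(),
                 true, 0.0f, 50.0f);
    }
    return 0;
}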
For reference, a more defensive way to write the shape function, in the style of TensorFlow's built-in LSTM block ops, is to validate the input ranks with WithRank before building the [batch_size, cell_size] output shape (the snippet below is for an op with seven outputs):
ShapeHandle x, cs_prev;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size = c->Dim(cs_prev, 1);
ShapeHandle output = c->Matrix(batch_size, cell_size);
for (int i = 0; i < 7; ++i) {
c->set_output(i, output);
}
return tensorflow::Status::OK();