| @@ -227,7 +227,7 @@ jobs: | |||
| uses: actions/cache@v3 | |||
| with: | |||
| path: lavapipe-install | |||
| key: lavapipe-linux-install-20211127-2 | |||
| key: lavapipe-linux-install-20211127-3 | |||
| - name: checkout-lavapipe | |||
| if: steps.cache-lavapipe.outputs.cache-hit != 'true' | |||
| uses: actions/checkout@v3 | |||
| @@ -1026,15 +1026,17 @@ y0, hidden y1, cell y2 = lstm(x0, hidden x1, cell x2) | |||
| | param id | name | type | default | description | | |||
| | --------- | ------------- | ----- | --------- | ----------------- | | |||
| | 0 | num_output | int | 0 | hidden size of output | | |||
| | 0 | num_output | int | 0 | output size of output | | |||
| | 1 | weight_data_size| int | 0 | total size of IFOG weight matrix | | |||
| | 2 | direction | int | 0 | 0=forward, 1=reverse, 2=bidirectional | | |||
| | 3 | hidden_size | int | num_output| hidden size | | |||
| | weight | type | shape | | |||
| | ------------- | ----- | --------------------- | | |||
| | weight_xc_data| float/fp16/int8 | [input_size, num_output * 4, num_directions] | | |||
| | bias_c_data | float/fp16/int8 | [num_output, 4, num_directions] | | |||
| | weight_hc_data| float/fp16/int8 | [num_output, num_output * 4, num_directions] | | |||
| | weight_xc_data| float/fp16/int8 | [input_size, hidden_size * 4, num_directions] | | |||
| | bias_c_data | float/fp16/int8 | [hidden_size, 4, num_directions] | | |||
| | weight_hc_data| float/fp16/int8 | [num_output, hidden_size * 4, num_directions] | | |||
| | weight_hr_data| float/fp16/int8 | [hidden_size, num_output, num_directions] | | |||
| Direction flag: | |||
| - 0 = forward only | |||
| @@ -58,11 +58,11 @@ int LSTM_arm::create_pipeline(const Option& opt) | |||
| // pack IFOG | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| int size = weight_data_size / num_directions / num_output / 4; | |||
| int size = weight_data_size / num_directions / hidden_size / 4; | |||
| weight_xc_data_packed.create(size, num_output, num_directions, 16u, 4); | |||
| bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4); | |||
| weight_hc_data_packed.create(num_output, num_output, num_directions, 16u, 4); | |||
| weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int dr = 0; dr < num_directions; dr++) | |||
| @@ -82,7 +82,7 @@ int LSTM_arm::create_pipeline(const Option& opt) | |||
| float* bias_c_IFOG = bias_c_data_packed_dr.row(0); | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| bias_c_IFOG[0] = bias_c_I[q]; | |||
| bias_c_IFOG[1] = bias_c_F[q]; | |||
| @@ -91,15 +91,15 @@ int LSTM_arm::create_pipeline(const Option& opt) | |||
| bias_c_IFOG += 4; | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); | |||
| float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); | |||
| @@ -126,21 +126,37 @@ int LSTM_arm::create_pipeline(const Option& opt) | |||
| } | |||
| } | |||
| if (opt.lightmode) | |||
| { | |||
| weight_xc_data.release(); | |||
| bias_c_data.release(); | |||
| weight_hc_data.release(); | |||
| } | |||
| return 0; | |||
| } | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(4, num_output, 4u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 4u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -155,7 +171,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| const float* x = bottom_blob.row(ti); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| const float* bias_c_IFOG = (const float*)bias_c + q * 4; | |||
| @@ -291,14 +307,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float* cell_ptr = cell_state; | |||
| float* hidden_ptr = hidden_state; | |||
| float* tmp_hidden_ptr = tmp_hidden_state; | |||
| int remain_num_output_start = 0; | |||
| int remain_hidden_size_start = 0; | |||
| #if __ARM_NEON | |||
| int nn_num_output = num_output >> 2; | |||
| remain_num_output_start = nn_num_output << 2; | |||
| int nn_hidden_size = hidden_size >> 2; | |||
| remain_hidden_size_start = nn_hidden_size << 2; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 4; | |||
| @@ -315,12 +332,20 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1q_f32(output_data + q, _H); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1q_f32(output_data + q, _H); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| } | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const float* gates_data = gates.row(q); | |||
| @@ -338,8 +363,43 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float H = O * tanh(cell2); | |||
| cell_ptr[q] = cell2; | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = H; | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = H; | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| if (num_output != hidden_size) | |||
| { | |||
| // int nn_num_output = num_output >> 2; | |||
| // int remain_num_output_start = nn_num_output << 2; | |||
| // #pragma omp parallel for num_threads(opt.num_threads) | |||
| // for (int qq = 0; qq < nn_num_output; qq++) | |||
| // { | |||
| // int q = qq * 4; | |||
| // | |||
| // } | |||
| int remain_num_output_start = 0; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| const float* tmp_hidden_ptr = tmp_hidden_state; | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_ptr[i] * hr[i]; | |||
| } | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = H; | |||
| } | |||
| } | |||
| } | |||
| @@ -375,7 +435,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -387,7 +447,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -402,14 +462,14 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.0f); | |||
| cell.fill(0.0f); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -466,7 +526,7 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -480,7 +540,7 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -497,13 +557,13 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -529,18 +589,27 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| } | |||
| #if NCNN_BF16 | |||
| static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(4, num_output, 4u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 4u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -555,7 +624,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| const unsigned short* x = bottom_blob.row<const unsigned short>(ti); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4; | |||
| @@ -693,14 +762,15 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float* cell_ptr = cell_state; | |||
| float* hidden_ptr = hidden_state; | |||
| float* tmp_hidden_ptr = tmp_hidden_state; | |||
| int remain_num_output_start = 0; | |||
| int remain_hidden_size_start = 0; | |||
| #if __ARM_NEON | |||
| int nn_num_output = num_output >> 2; | |||
| remain_num_output_start = nn_num_output << 2; | |||
| int nn_hidden_size = hidden_size >> 2; | |||
| remain_hidden_size_start = nn_hidden_size << 2; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 4; | |||
| @@ -717,12 +787,20 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_u16(output_data + q, bfloat2float(_H)); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_u16(output_data + q, bfloat2float(_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| } | |||
| } | |||
| #endif // __ARM_NEON | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const float* gates_data = gates.row(q); | |||
| @@ -740,8 +818,43 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float H = O * tanh(cell2); | |||
| cell_ptr[q] = cell2; | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = float32_to_bfloat16(H); | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = float32_to_bfloat16(H); | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| if (num_output != hidden_size) | |||
| { | |||
| // int nn_num_output = num_output >> 2; | |||
| // int remain_num_output_start = nn_num_output << 2; | |||
| // #pragma omp parallel for num_threads(opt.num_threads) | |||
| // for (int qq = 0; qq < nn_num_output; qq++) | |||
| // { | |||
| // int q = qq * 4; | |||
| // | |||
| // } | |||
| int remain_num_output_start = 0; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| const float* tmp_hidden_ptr = tmp_hidden_state; | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_ptr[i] * hr[i]; | |||
| } | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = float32_to_bfloat16(H); | |||
| } | |||
| } | |||
| } | |||
| @@ -752,11 +865,11 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) | |||
| { | |||
| // pack IFOG | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| int size = weight_data_size / num_directions / num_output / 4; | |||
| int size = weight_data_size / num_directions / hidden_size / 4; | |||
| weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4); | |||
| bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4); | |||
| weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int dr = 0; dr < num_directions; dr++) | |||
| @@ -776,7 +889,7 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) | |||
| unsigned short* bias_c_IFOG = bias_c_data_packed_dr.row<unsigned short>(0); | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| bias_c_IFOG[0] = float32_to_bfloat16(bias_c_I[q]); | |||
| bias_c_IFOG[1] = float32_to_bfloat16(bias_c_F[q]); | |||
| @@ -785,15 +898,15 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) | |||
| bias_c_IFOG += 4; | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| unsigned short* weight_xc_IFOG = weight_xc_data_packed_dr.row<unsigned short>(q); | |||
| unsigned short* weight_hc_IFOG = weight_hc_data_packed_dr.row<unsigned short>(q); | |||
| @@ -820,6 +933,13 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) | |||
| } | |||
| } | |||
| if (opt.lightmode) | |||
| { | |||
| weight_xc_data.release(); | |||
| bias_c_data.release(); | |||
| weight_hc_data.release(); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -835,7 +955,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -847,7 +967,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -862,14 +982,14 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.f); | |||
| cell.fill(0.f); | |||
| int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -911,7 +1031,7 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -925,7 +1045,7 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -942,13 +1062,13 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -25,18 +25,27 @@ | |||
| namespace ncnn { | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(4, num_output, 4u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 4u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -51,7 +60,7 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| const __fp16* x = bottom_blob.row<const __fp16>(ti); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; | |||
| @@ -141,11 +150,12 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float* cell_ptr = cell_state; | |||
| float* hidden_ptr = hidden_state; | |||
| float* tmp_hidden_ptr = tmp_hidden_state; | |||
| int nn_num_output = num_output >> 2; | |||
| int remain_num_output_start = nn_num_output << 2; | |||
| int nn_hidden_size = hidden_size >> 2; | |||
| int remain_hidden_size_start = nn_hidden_size << 2; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 4; | |||
| @@ -162,11 +172,19 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| } | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const float* gates_data = gates.row(q); | |||
| @@ -184,26 +202,70 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float H = O * tanh(cell2); | |||
| cell_ptr[q] = cell2; | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)(H); | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)H; | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| if (num_output != hidden_size) | |||
| { | |||
| // int nn_num_output = num_output >> 2; | |||
| // int remain_num_output_start = nn_num_output << 2; | |||
| // #pragma omp parallel for num_threads(opt.num_threads) | |||
| // for (int qq = 0; qq < nn_num_output; qq++) | |||
| // { | |||
| // int q = qq * 4; | |||
| // | |||
| // } | |||
| int remain_num_output_start = 0; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| const float* tmp_hidden_ptr = tmp_hidden_state; | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_ptr[i] * hr[i]; | |||
| } | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)H; | |||
| } | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(4, num_output, 2u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 2u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -216,10 +278,10 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| int ti = reverse ? T - 1 - t : t; | |||
| int nn_num_output = num_output >> 1; | |||
| int remain_num_output_start = nn_num_output << 1; | |||
| int nn_hidden_size = hidden_size >> 1; | |||
| int remain_hidden_size_start = nn_hidden_size << 1; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 2; | |||
| @@ -319,7 +381,7 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| vst1q_f16(gates_data, _IFOG); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; | |||
| @@ -428,11 +490,12 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float* cell_ptr = cell_state; | |||
| float* hidden_ptr = hidden_state; | |||
| float* tmp_hidden_ptr = tmp_hidden_state; | |||
| nn_num_output = num_output >> 2; | |||
| remain_num_output_start = nn_num_output << 2; | |||
| nn_hidden_size = hidden_size >> 2; | |||
| remain_hidden_size_start = nn_hidden_size << 2; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 4; | |||
| @@ -449,11 +512,19 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); | |||
| vst1q_f32(cell_ptr + q, _cell2); | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| if (num_output == hidden_size) | |||
| { | |||
| vst1q_f32(hidden_ptr + q, _H); | |||
| vst1_f16(output_data + q, vcvt_f16_f32(_H)); | |||
| } | |||
| else | |||
| { | |||
| vst1q_f32(tmp_hidden_ptr + q, _H); | |||
| } | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const __fp16* gates_data = gates.row<const __fp16>(q); | |||
| @@ -471,8 +542,43 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const | |||
| float H = O * tanh(cell2); | |||
| cell_ptr[q] = cell2; | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)H; | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)H; | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| if (num_output != hidden_size) | |||
| { | |||
| // int nn_num_output = num_output >> 2; | |||
| // int remain_num_output_start = nn_num_output << 2; | |||
| // #pragma omp parallel for num_threads(opt.num_threads) | |||
| // for (int qq = 0; qq < nn_num_output; qq++) | |||
| // { | |||
| // int q = qq * 4; | |||
| // | |||
| // } | |||
| int remain_num_output_start = 0; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| const float* tmp_hidden_ptr = tmp_hidden_state; | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_ptr[i] * hr[i]; | |||
| } | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = (__fp16)H; | |||
| } | |||
| } | |||
| } | |||
| @@ -483,19 +589,19 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| { | |||
| // pack IFOG | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| int size = weight_data_size / num_directions / num_output / 4; | |||
| int size = weight_data_size / num_directions / hidden_size / 4; | |||
| if (opt.use_fp16_arithmetic) | |||
| { | |||
| weight_xc_data_packed.create(size, num_output / 2 + num_output % 2, num_directions, 16u, 8); | |||
| bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, num_output / 2 + num_output % 2, num_directions, 16u, 8); | |||
| weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); | |||
| } | |||
| else | |||
| { | |||
| weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4); | |||
| bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4); | |||
| weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); | |||
| } | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| @@ -519,7 +625,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| if (opt.use_fp16_arithmetic) | |||
| { | |||
| int q = 0; | |||
| for (; q + 1 < num_output; q += 2) | |||
| for (; q + 1 < hidden_size; q += 2) | |||
| { | |||
| bias_c_IFOG[0] = (__fp16)bias_c_I[q]; | |||
| bias_c_IFOG[1] = (__fp16)bias_c_F[q]; | |||
| @@ -532,23 +638,23 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| bias_c_IFOG += 8; | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + q + 1); | |||
| const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + q + 1); | |||
| const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + q + 1); | |||
| const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + q + 1); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + q + 1); | |||
| const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + q + 1); | |||
| const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + q + 1); | |||
| const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + q + 1); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); | |||
| const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); | |||
| const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); | |||
| const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); | |||
| const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); | |||
| const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); | |||
| const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); | |||
| __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2); | |||
| __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2); | |||
| @@ -581,7 +687,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| weight_hc_IFOG += 8; | |||
| } | |||
| } | |||
| for (; q < num_output; q++) | |||
| for (; q < hidden_size; q++) | |||
| { | |||
| bias_c_IFOG[0] = (__fp16)bias_c_I[q]; | |||
| bias_c_IFOG[1] = (__fp16)bias_c_F[q]; | |||
| @@ -590,15 +696,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| bias_c_IFOG += 4; | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2); | |||
| __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2); | |||
| @@ -626,7 +732,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| } | |||
| else | |||
| { | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| bias_c_IFOG[0] = (__fp16)bias_c_I[q]; | |||
| bias_c_IFOG[1] = (__fp16)bias_c_F[q]; | |||
| @@ -635,15 +741,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| bias_c_IFOG += 4; | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q); | |||
| __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q); | |||
| @@ -671,6 +777,13 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) | |||
| } | |||
| } | |||
| if (opt.lightmode) | |||
| { | |||
| weight_xc_data.release(); | |||
| bias_c_data.release(); | |||
| weight_hc_data.release(); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -686,7 +799,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -698,7 +811,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -713,14 +826,14 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.f); | |||
| cell.fill(0.f); | |||
| int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -762,7 +875,7 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -776,7 +889,7 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -793,13 +906,13 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -836,7 +949,7 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -848,7 +961,7 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -863,14 +976,14 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.f); | |||
| cell.fill(0.f); | |||
| int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -912,7 +1025,7 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -926,7 +1039,7 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); | |||
| int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -943,13 +1056,13 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -29,6 +29,7 @@ int LSTM::load_param(const ParamDict& pd) | |||
| num_output = pd.get(0, 0); | |||
| weight_data_size = pd.get(1, 0); | |||
| direction = pd.get(2, 0); | |||
| hidden_size = pd.get(3, num_output); | |||
| return 0; | |||
| } | |||
| @@ -36,36 +37,52 @@ int LSTM::load_model(const ModelBin& mb) | |||
| { | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| int size = weight_data_size / num_directions / num_output / 4; | |||
| int size = weight_data_size / num_directions / hidden_size / 4; | |||
| // raw weight data | |||
| weight_xc_data = mb.load(size, num_output * 4, num_directions, 0); | |||
| weight_xc_data = mb.load(size, hidden_size * 4, num_directions, 0); | |||
| if (weight_xc_data.empty()) | |||
| return -100; | |||
| bias_c_data = mb.load(num_output, 4, num_directions, 0); | |||
| bias_c_data = mb.load(hidden_size, 4, num_directions, 0); | |||
| if (bias_c_data.empty()) | |||
| return -100; | |||
| weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0); | |||
| weight_hc_data = mb.load(num_output, hidden_size * 4, num_directions, 0); | |||
| if (weight_hc_data.empty()) | |||
| return -100; | |||
| if (num_output != hidden_size) | |||
| { | |||
| weight_hr_data = mb.load(hidden_size, num_output, num_directions, 0); | |||
| if (weight_hr_data.empty()) | |||
| return -100; | |||
| } | |||
| return 0; | |||
| } | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(4, num_output, 4u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 4u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -80,7 +97,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| const float* x = bottom_blob.row(ti); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| const float* bias_c_I = bias_c.row(0); | |||
| const float* bias_c_F = bias_c.row(1); | |||
| @@ -90,15 +107,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float* gates_data = gates.row(q); | |||
| // gate I F O G | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| float I = bias_c_I[q]; | |||
| float F = bias_c_F[q]; | |||
| @@ -140,7 +157,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| // h_t := o_t .* tanh[c_t] | |||
| float* output_data = top_blob.row(ti); | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| for (int q = 0; q < hidden_size; q++) | |||
| { | |||
| const float* gates_data = gates.row(q); | |||
| @@ -157,8 +174,34 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| float cell2 = F * cell_state[q] + I * G; | |||
| float H = O * tanh(cell2); | |||
| cell_state[q] = cell2; | |||
| hidden_state[q] = H; | |||
| output_data[q] = H; | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_state[q] = H; | |||
| output_data[q] = H; | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_state[q] = H; | |||
| } | |||
| } | |||
| if (num_output != hidden_size) | |||
| { | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = 0; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_state[i] * hr[i]; | |||
| } | |||
| hidden_state[q] = H; | |||
| output_data[q] = H; | |||
| } | |||
| } | |||
| } | |||
| @@ -177,7 +220,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -189,7 +232,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -204,14 +247,14 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.0f); | |||
| cell.fill(0.0f); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -251,7 +294,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -265,7 +308,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -282,13 +325,13 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -36,10 +36,12 @@ public: | |||
| int num_output; | |||
| int weight_data_size; | |||
| int direction; // 0=forward 1=reverse 2=bidirectional | |||
| int hidden_size; | |||
| Mat weight_hc_data; | |||
| Mat weight_xc_data; | |||
| Mat bias_c_data; | |||
| Mat weight_hr_data; | |||
| }; | |||
| } // namespace ncnn | |||
| @@ -14,6 +14,13 @@ | |||
| #include "lstm_x86.h" | |||
| #if __SSE2__ | |||
| #include <emmintrin.h> | |||
| #if __AVX__ | |||
| #include <immintrin.h> | |||
| #endif | |||
| #endif // __SSE2__ | |||
| #include "x86_activation.h" | |||
| #include "x86_usability.h" | |||
| @@ -30,23 +37,183 @@ LSTM_x86::LSTM_x86() | |||
| int LSTM_x86::create_pipeline(const Option& opt) | |||
| { | |||
| (void)(opt); | |||
| // pack IFOG | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| int size = weight_data_size / num_directions / hidden_size / 4; | |||
| #if __AVX__ | |||
| weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8); | |||
| #else | |||
| weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4); | |||
| bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); | |||
| weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4); | |||
| #endif | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int dr = 0; dr < num_directions; dr++) | |||
| { | |||
| const Mat weight_xc = weight_xc_data.channel(dr); | |||
| const Mat bias_c = bias_c_data.channel(dr); | |||
| const Mat weight_hc = weight_hc_data.channel(dr); | |||
| Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); | |||
| Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); | |||
| Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); | |||
| const float* bias_c_I = bias_c.row(0); | |||
| const float* bias_c_F = bias_c.row(1); | |||
| const float* bias_c_O = bias_c.row(2); | |||
| const float* bias_c_G = bias_c.row(3); | |||
| float* bias_c_IFOG = bias_c_data_packed_dr.row(0); | |||
| int q = 0; | |||
| #if __AVX__ | |||
| for (; q + 1 < hidden_size; q += 2) | |||
| { | |||
| bias_c_IFOG[0] = bias_c_I[q]; | |||
| bias_c_IFOG[1] = bias_c_F[q]; | |||
| bias_c_IFOG[2] = bias_c_O[q]; | |||
| bias_c_IFOG[3] = bias_c_G[q]; | |||
| bias_c_IFOG[4] = bias_c_I[q + 1]; | |||
| bias_c_IFOG[5] = bias_c_F[q + 1]; | |||
| bias_c_IFOG[6] = bias_c_O[q + 1]; | |||
| bias_c_IFOG[7] = bias_c_G[q + 1]; | |||
| bias_c_IFOG += 8; | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); | |||
| const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); | |||
| const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); | |||
| const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); | |||
| const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); | |||
| const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); | |||
| const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); | |||
| float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2); | |||
| float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| weight_xc_IFOG[0] = weight_xc_I[i]; | |||
| weight_xc_IFOG[1] = weight_xc_F[i]; | |||
| weight_xc_IFOG[2] = weight_xc_O[i]; | |||
| weight_xc_IFOG[3] = weight_xc_G[i]; | |||
| weight_xc_IFOG[4] = weight_xc_I_1[i]; | |||
| weight_xc_IFOG[5] = weight_xc_F_1[i]; | |||
| weight_xc_IFOG[6] = weight_xc_O_1[i]; | |||
| weight_xc_IFOG[7] = weight_xc_G_1[i]; | |||
| weight_xc_IFOG += 8; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| { | |||
| weight_hc_IFOG[0] = weight_hc_I[i]; | |||
| weight_hc_IFOG[1] = weight_hc_F[i]; | |||
| weight_hc_IFOG[2] = weight_hc_O[i]; | |||
| weight_hc_IFOG[3] = weight_hc_G[i]; | |||
| weight_hc_IFOG[4] = weight_hc_I_1[i]; | |||
| weight_hc_IFOG[5] = weight_hc_F_1[i]; | |||
| weight_hc_IFOG[6] = weight_hc_O_1[i]; | |||
| weight_hc_IFOG[7] = weight_hc_G_1[i]; | |||
| weight_hc_IFOG += 8; | |||
| } | |||
| } | |||
| #endif // __AVX__ | |||
| for (; q < hidden_size; q++) | |||
| { | |||
| bias_c_IFOG[0] = bias_c_I[q]; | |||
| bias_c_IFOG[1] = bias_c_F[q]; | |||
| bias_c_IFOG[2] = bias_c_O[q]; | |||
| bias_c_IFOG[3] = bias_c_G[q]; | |||
| bias_c_IFOG += 4; | |||
| const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); | |||
| #if __AVX__ | |||
| float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2); | |||
| float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2); | |||
| #else | |||
| float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); | |||
| float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); | |||
| #endif | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| weight_xc_IFOG[0] = weight_xc_I[i]; | |||
| weight_xc_IFOG[1] = weight_xc_F[i]; | |||
| weight_xc_IFOG[2] = weight_xc_O[i]; | |||
| weight_xc_IFOG[3] = weight_xc_G[i]; | |||
| weight_xc_IFOG += 4; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| { | |||
| weight_hc_IFOG[0] = weight_hc_I[i]; | |||
| weight_hc_IFOG[1] = weight_hc_F[i]; | |||
| weight_hc_IFOG[2] = weight_hc_O[i]; | |||
| weight_hc_IFOG[3] = weight_hc_G[i]; | |||
| weight_hc_IFOG += 4; | |||
| } | |||
| } | |||
| } | |||
| if (opt.lightmode) | |||
| { | |||
| weight_xc_data.release(); | |||
| bias_c_data.release(); | |||
| weight_hc_data.release(); | |||
| } | |||
| return 0; | |||
| } | |||
| #ifdef __AVX__ | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) | |||
| { | |||
| int size = bottom_blob.w; | |||
| int T = bottom_blob.h; | |||
| int num_output = top_blob.w; | |||
| int hidden_size = cell_state.w; | |||
| // 4 x num_output | |||
| Mat gates(num_output, 4, 4u, opt.workspace_allocator); | |||
| // 4 x hidden_size | |||
| Mat gates(4, hidden_size, 4u, opt.workspace_allocator); | |||
| if (gates.empty()) | |||
| return -100; | |||
| Mat tmp_hidden_state; | |||
| if (num_output != hidden_size) | |||
| { | |||
| tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); | |||
| if (tmp_hidden_state.empty()) | |||
| return -100; | |||
| } | |||
| // unroll | |||
| for (int t = 0; t < T; t++) | |||
| { | |||
| @@ -59,267 +226,222 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| int ti = reverse ? T - 1 - t : t; | |||
| int nn_num_output = num_output >> 1; | |||
| int remain_num_output_start = nn_num_output << 1; | |||
| #if __AVX__ | |||
| int nn_hidden_size = hidden_size >> 1; | |||
| int remain_hidden_size_start = nn_hidden_size << 1; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_num_output; qq++) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| int q = qq * 2; | |||
| const float* x = bottom_blob.row(ti); | |||
| const float* hidden_ptr_r = hidden_state; | |||
| const float* bias_c_I = bias_c.row(0); | |||
| const float* bias_c_F = bias_c.row(1); | |||
| const float* bias_c_O = bias_c.row(2); | |||
| const float* bias_c_G = bias_c.row(3); | |||
| float* gates_data_I = gates.row(0); | |||
| float* gates_data_F = gates.row(1); | |||
| float* gates_data_O = gates.row(2); | |||
| float* gates_data_G = gates.row(3); | |||
| const float* bias_c_IFOG = (const float*)bias_c + q * 4; | |||
| // gate I F O G | |||
| const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1)); | |||
| const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1)); | |||
| const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1)); | |||
| const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1)); | |||
| const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q); | |||
| const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1)); | |||
| const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1)); | |||
| const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1)); | |||
| const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1)); | |||
| // float I = bias_c_I[q]; | |||
| // float F = bias_c_F[q]; | |||
| // float O = bias_c_O[q]; | |||
| // float G = bias_c_G[q]; | |||
| __m256 _sumI_0 = _mm256_setzero_ps(); | |||
| __m256 _sumF_0 = _mm256_setzero_ps(); | |||
| __m256 _sumO_0 = _mm256_setzero_ps(); | |||
| __m256 _sumG_0 = _mm256_setzero_ps(); | |||
| __m256 _sumI_1 = _mm256_setzero_ps(); | |||
| __m256 _sumF_1 = _mm256_setzero_ps(); | |||
| __m256 _sumO_1 = _mm256_setzero_ps(); | |||
| __m256 _sumG_1 = _mm256_setzero_ps(); | |||
| int nn_num_size = size >> 3; | |||
| int remain_size = size & 7; | |||
| for (; nn_num_size > 0; nn_num_size--) | |||
| const float* weight_xc_IFOG = weight_xc.row(q / 2); | |||
| const float* weight_hc_IFOG = weight_hc.row(q / 2); | |||
| __m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG); | |||
| __m256 _sum1 = _mm256_setzero_ps(); | |||
| __m256 _sum2 = _mm256_setzero_ps(); | |||
| __m256 _sum3 = _mm256_setzero_ps(); | |||
| const float* x = bottom_blob.row(ti); | |||
| int i = 0; | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| __m256 xi = _mm256_loadu_ps(x); | |||
| _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0); | |||
| _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0); | |||
| _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0); | |||
| _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0); | |||
| _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1); | |||
| _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1); | |||
| _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1); | |||
| _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1); | |||
| x += 8; | |||
| weight_xc_I_0 += 8; | |||
| weight_xc_F_0 += 8; | |||
| weight_xc_O_0 += 8; | |||
| weight_xc_G_0 += 8; | |||
| weight_xc_I_1 += 8; | |||
| weight_xc_F_1 += 8; | |||
| weight_xc_O_1 += 8; | |||
| weight_xc_G_1 += 8; | |||
| __m256 _xi0 = _mm256_broadcast_ss(x); | |||
| __m256 _xi1 = _mm256_broadcast_ss(x + 1); | |||
| __m256 _xi2 = _mm256_broadcast_ss(x + 2); | |||
| __m256 _xi3 = _mm256_broadcast_ss(x + 3); | |||
| __m256 _weight_xc_IFOG0 = _mm256_loadu_ps(weight_xc_IFOG); | |||
| __m256 _weight_xc_IFOG1 = _mm256_loadu_ps(weight_xc_IFOG + 8); | |||
| __m256 _weight_xc_IFOG2 = _mm256_loadu_ps(weight_xc_IFOG + 16); | |||
| __m256 _weight_xc_IFOG3 = _mm256_loadu_ps(weight_xc_IFOG + 24); | |||
| _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); | |||
| _sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); | |||
| _sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); | |||
| _sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); | |||
| x += 4; | |||
| weight_xc_IFOG += 32; | |||
| } | |||
| int nn_num_output = num_output >> 3; | |||
| int remain_num_output = num_output & 7; | |||
| for (; nn_num_output > 0; nn_num_output--) | |||
| for (; i < size; i++) | |||
| { | |||
| __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r); | |||
| _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0); | |||
| _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0); | |||
| _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0); | |||
| _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0); | |||
| _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1); | |||
| _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1); | |||
| _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1); | |||
| _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1); | |||
| hidden_ptr_r += 8; | |||
| weight_hc_I_0 += 8; | |||
| weight_hc_F_0 += 8; | |||
| weight_hc_O_0 += 8; | |||
| weight_hc_G_0 += 8; | |||
| weight_hc_I_1 += 8; | |||
| weight_hc_F_1 += 8; | |||
| weight_hc_O_1 += 8; | |||
| weight_hc_G_1 += 8; | |||
| __m256 _xi = _mm256_broadcast_ss(x); | |||
| __m256 _weight_xc_IFOG = _mm256_loadu_ps(weight_xc_IFOG); | |||
| _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); | |||
| x += 1; | |||
| weight_xc_IFOG += 8; | |||
| } | |||
| float sums[8]; | |||
| _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1)); | |||
| sums[0] += bias_c_I[q]; | |||
| sums[1] += bias_c_F[q]; | |||
| sums[2] += bias_c_O[q]; | |||
| sums[3] += bias_c_G[q]; | |||
| sums[4] += bias_c_I[q + 1]; | |||
| sums[5] += bias_c_F[q + 1]; | |||
| sums[6] += bias_c_O[q + 1]; | |||
| sums[7] += bias_c_G[q + 1]; | |||
| for (; remain_size > 0; remain_size--) | |||
| const float* hidden_ptr = hidden_state; | |||
| i = 0; | |||
| for (; i + 3 < num_output; i += 4) | |||
| { | |||
| float xi = *x; | |||
| sums[0] += *weight_xc_I_0 * xi; | |||
| sums[1] += *weight_xc_F_0 * xi; | |||
| sums[2] += *weight_xc_O_0 * xi; | |||
| sums[3] += *weight_xc_G_0 * xi; | |||
| sums[4] += *weight_xc_I_1 * xi; | |||
| sums[5] += *weight_xc_F_1 * xi; | |||
| sums[6] += *weight_xc_O_1 * xi; | |||
| sums[7] += *weight_xc_G_1 * xi; | |||
| x++; | |||
| weight_xc_I_0++; | |||
| weight_xc_F_0++; | |||
| weight_xc_O_0++; | |||
| weight_xc_G_0++; | |||
| weight_xc_I_1++; | |||
| weight_xc_F_1++; | |||
| weight_xc_O_1++; | |||
| weight_xc_G_1++; | |||
| __m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr); | |||
| __m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1); | |||
| __m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2); | |||
| __m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3); | |||
| __m256 _weight_hc_IFOG0 = _mm256_loadu_ps(weight_hc_IFOG); | |||
| __m256 _weight_hc_IFOG1 = _mm256_loadu_ps(weight_hc_IFOG + 8); | |||
| __m256 _weight_hc_IFOG2 = _mm256_loadu_ps(weight_hc_IFOG + 16); | |||
| __m256 _weight_hc_IFOG3 = _mm256_loadu_ps(weight_hc_IFOG + 24); | |||
| _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); | |||
| _sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); | |||
| _sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); | |||
| _sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); | |||
| hidden_ptr += 4; | |||
| weight_hc_IFOG += 32; | |||
| } | |||
| for (; remain_num_output > 0; remain_num_output--) | |||
| for (; i < num_output; i++) | |||
| { | |||
| float h_cont = *hidden_ptr_r; | |||
| sums[0] += *weight_hc_I_0 * h_cont; | |||
| sums[1] += *weight_hc_F_0 * h_cont; | |||
| sums[2] += *weight_hc_O_0 * h_cont; | |||
| sums[3] += *weight_hc_G_0 * h_cont; | |||
| sums[4] += *weight_hc_I_1 * h_cont; | |||
| sums[5] += *weight_hc_F_1 * h_cont; | |||
| sums[6] += *weight_hc_O_1 * h_cont; | |||
| sums[7] += *weight_hc_G_1 * h_cont; | |||
| hidden_ptr_r++; | |||
| weight_hc_I_0++; | |||
| weight_hc_F_0++; | |||
| weight_hc_O_0++; | |||
| weight_hc_G_0++; | |||
| weight_hc_I_1++; | |||
| weight_hc_F_1++; | |||
| weight_hc_O_1++; | |||
| weight_hc_G_1++; | |||
| __m256 _h_cont = _mm256_broadcast_ss(hidden_ptr); | |||
| __m256 _weight_hc_IFOG = _mm256_loadu_ps(weight_hc_IFOG); | |||
| _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); | |||
| hidden_ptr += 1; | |||
| weight_hc_IFOG += 8; | |||
| } | |||
| gates_data_I[q] = sums[0]; | |||
| gates_data_F[q] = sums[1]; | |||
| gates_data_O[q] = sums[2]; | |||
| gates_data_G[q] = sums[3]; | |||
| gates_data_I[q + 1] = sums[4]; | |||
| gates_data_F[q + 1] = sums[5]; | |||
| gates_data_O[q + 1] = sums[6]; | |||
| gates_data_G[q + 1] = sums[7]; | |||
| float* gates_data = gates.row(q); | |||
| _IFOG = _mm256_add_ps(_IFOG, _sum1); | |||
| _sum2 = _mm256_add_ps(_sum2, _sum3); | |||
| _IFOG = _mm256_add_ps(_IFOG, _sum2); | |||
| _mm256_storeu_ps(gates_data, _IFOG); | |||
| } | |||
| #else | |||
| int nn_hidden_size = 0; | |||
| int remain_hidden_size_start = 0; | |||
| #endif // __AVX__ | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| const float* x = bottom_blob.row(ti); | |||
| const float* hidden_ptr_r = hidden_state; | |||
| const float* bias_c_I = bias_c.row(0); | |||
| const float* bias_c_F = bias_c.row(1); | |||
| const float* bias_c_O = bias_c.row(2); | |||
| const float* bias_c_G = bias_c.row(3); | |||
| float* gates_data_I = gates.row(0); | |||
| float* gates_data_F = gates.row(1); | |||
| float* gates_data_O = gates.row(2); | |||
| float* gates_data_G = gates.row(3); | |||
| const float* bias_c_IFOG = (const float*)bias_c + q * 4; | |||
| // gate I F O G | |||
| const float* weight_xc_I = weight_xc.row(num_output * 0 + q); | |||
| const float* weight_xc_F = weight_xc.row(num_output * 1 + q); | |||
| const float* weight_xc_O = weight_xc.row(num_output * 2 + q); | |||
| const float* weight_xc_G = weight_xc.row(num_output * 3 + q); | |||
| const float* weight_hc_I = weight_hc.row(num_output * 0 + q); | |||
| const float* weight_hc_F = weight_hc.row(num_output * 1 + q); | |||
| const float* weight_hc_O = weight_hc.row(num_output * 2 + q); | |||
| const float* weight_hc_G = weight_hc.row(num_output * 3 + q); | |||
| // float I = bias_c_I[q]; | |||
| // float F = bias_c_F[q]; | |||
| // float O = bias_c_O[q]; | |||
| // float G = bias_c_G[q]; | |||
| __m256 _sumI = _mm256_setzero_ps(); | |||
| __m256 _sumF = _mm256_setzero_ps(); | |||
| __m256 _sumO = _mm256_setzero_ps(); | |||
| __m256 _sumG = _mm256_setzero_ps(); | |||
| int nn_num_size = size >> 3; | |||
| int remain_size = size & 7; | |||
| for (; nn_num_size > 0; nn_num_size--) | |||
| #if __AVX__ | |||
| const float* weight_xc_IFOG = weight_xc.row(q / 2 + q % 2); | |||
| const float* weight_hc_IFOG = weight_hc.row(q / 2 + q % 2); | |||
| #else | |||
| const float* weight_xc_IFOG = weight_xc.row(q); | |||
| const float* weight_hc_IFOG = weight_hc.row(q); | |||
| #endif | |||
| #if __SSE2__ | |||
| __m128 _IFOG = _mm_loadu_ps(bias_c_IFOG); | |||
| __m128 _sum1 = _mm_setzero_ps(); | |||
| __m128 _sum2 = _mm_setzero_ps(); | |||
| __m128 _sum3 = _mm_setzero_ps(); | |||
| #else // __SSE2__ | |||
| float I = bias_c_IFOG[0]; | |||
| float F = bias_c_IFOG[1]; | |||
| float O = bias_c_IFOG[2]; | |||
| float G = bias_c_IFOG[3]; | |||
| #endif // __SSE2__ | |||
| const float* x = bottom_blob.row(ti); | |||
| int i = 0; | |||
| #if __SSE2__ | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| __m256 xi = _mm256_loadu_ps(x); | |||
| _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI); | |||
| _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF); | |||
| _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO); | |||
| _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG); | |||
| x += 8; | |||
| weight_xc_I += 8; | |||
| weight_xc_F += 8; | |||
| weight_xc_O += 8; | |||
| weight_xc_G += 8; | |||
| __m128 _xi0 = _mm_load1_ps(x); | |||
| __m128 _xi1 = _mm_load1_ps(x + 1); | |||
| __m128 _xi2 = _mm_load1_ps(x + 2); | |||
| __m128 _xi3 = _mm_load1_ps(x + 3); | |||
| __m128 _weight_xc_IFOG0 = _mm_loadu_ps(weight_xc_IFOG); | |||
| __m128 _weight_xc_IFOG1 = _mm_loadu_ps(weight_xc_IFOG + 4); | |||
| __m128 _weight_xc_IFOG2 = _mm_loadu_ps(weight_xc_IFOG + 8); | |||
| __m128 _weight_xc_IFOG3 = _mm_loadu_ps(weight_xc_IFOG + 12); | |||
| _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); | |||
| _sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); | |||
| _sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); | |||
| _sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); | |||
| x += 4; | |||
| weight_xc_IFOG += 16; | |||
| } | |||
| int nn_num_output = num_output >> 3; | |||
| int remain_num_output = num_output & 7; | |||
| for (; nn_num_output > 0; nn_num_output--) | |||
| #endif // __SSE2__ | |||
| for (; i < size; i++) | |||
| { | |||
| __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r); | |||
| _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI); | |||
| _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF); | |||
| _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO); | |||
| _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG); | |||
| hidden_ptr_r += 8; | |||
| weight_hc_I += 8; | |||
| weight_hc_F += 8; | |||
| weight_hc_O += 8; | |||
| weight_hc_G += 8; | |||
| #if __SSE2__ | |||
| __m128 _xi = _mm_load1_ps(x); | |||
| __m128 _weight_xc_IFOG = _mm_loadu_ps(weight_xc_IFOG); | |||
| _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); | |||
| #else // __SSE2__ | |||
| float xi = x[0]; | |||
| I += xi * weight_xc_IFOG[0]; | |||
| F += xi * weight_xc_IFOG[1]; | |||
| O += xi * weight_xc_IFOG[2]; | |||
| G += xi * weight_xc_IFOG[3]; | |||
| #endif // __SSE2__ | |||
| x += 1; | |||
| weight_xc_IFOG += 4; | |||
| } | |||
| float sums[4]; | |||
| _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG)); | |||
| sums[0] += bias_c_I[q]; | |||
| sums[1] += bias_c_F[q]; | |||
| sums[2] += bias_c_O[q]; | |||
| sums[3] += bias_c_G[q]; | |||
| for (; remain_size > 0; remain_size--) | |||
| const float* hidden_ptr = hidden_state; | |||
| i = 0; | |||
| #if __SSE2__ | |||
| for (; i + 3 < num_output; i += 4) | |||
| { | |||
| float xi = *x; | |||
| sums[0] += *weight_xc_I * xi; | |||
| sums[1] += *weight_xc_F * xi; | |||
| sums[2] += *weight_xc_O * xi; | |||
| sums[3] += *weight_xc_G * xi; | |||
| x++; | |||
| weight_xc_I++; | |||
| weight_xc_F++; | |||
| weight_xc_O++; | |||
| weight_xc_G++; | |||
| __m128 _h_cont0 = _mm_load1_ps(hidden_ptr); | |||
| __m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1); | |||
| __m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2); | |||
| __m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3); | |||
| __m128 _weight_hc_IFOG0 = _mm_loadu_ps(weight_hc_IFOG); | |||
| __m128 _weight_hc_IFOG1 = _mm_loadu_ps(weight_hc_IFOG + 4); | |||
| __m128 _weight_hc_IFOG2 = _mm_loadu_ps(weight_hc_IFOG + 8); | |||
| __m128 _weight_hc_IFOG3 = _mm_loadu_ps(weight_hc_IFOG + 12); | |||
| _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); | |||
| _sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); | |||
| _sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); | |||
| _sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); | |||
| hidden_ptr += 4; | |||
| weight_hc_IFOG += 16; | |||
| } | |||
| for (; remain_num_output > 0; remain_num_output--) | |||
| #endif // __SSE2__ | |||
| for (; i < num_output; i++) | |||
| { | |||
| float h_cont = *hidden_ptr_r; | |||
| sums[0] += *weight_hc_I * h_cont; | |||
| sums[1] += *weight_hc_F * h_cont; | |||
| sums[2] += *weight_hc_O * h_cont; | |||
| sums[3] += *weight_hc_G * h_cont; | |||
| hidden_ptr_r++; | |||
| weight_hc_I++; | |||
| weight_hc_F++; | |||
| weight_hc_O++; | |||
| weight_hc_G++; | |||
| #if __SSE2__ | |||
| __m128 _h_cont = _mm_load1_ps(hidden_ptr); | |||
| __m128 _weight_hc_IFOG = _mm_loadu_ps(weight_hc_IFOG); | |||
| _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); | |||
| #else // __SSE2__ | |||
| float h_cont = hidden_ptr[0]; | |||
| I += h_cont * weight_hc_IFOG[0]; | |||
| F += h_cont * weight_hc_IFOG[1]; | |||
| O += h_cont * weight_hc_IFOG[2]; | |||
| G += h_cont * weight_hc_IFOG[3]; | |||
| #endif // __SSE2__ | |||
| hidden_ptr += 1; | |||
| weight_hc_IFOG += 4; | |||
| } | |||
| gates_data_I[q] = sums[0]; | |||
| gates_data_F[q] = sums[1]; | |||
| gates_data_O[q] = sums[2]; | |||
| gates_data_G[q] = sums[3]; | |||
| float* gates_data = gates.row(q); | |||
| #if __SSE2__ | |||
| _IFOG = _mm_add_ps(_IFOG, _sum1); | |||
| _sum2 = _mm_add_ps(_sum2, _sum3); | |||
| _IFOG = _mm_add_ps(_IFOG, _sum2); | |||
| _mm_storeu_ps(gates_data, _IFOG); | |||
| #else // __SSE2__ | |||
| gates_data[0] = I; | |||
| gates_data[1] = F; | |||
| gates_data[2] = O; | |||
| gates_data[3] = G; | |||
| #endif // __SSE2__ | |||
| } | |||
| // lstm unit | |||
| @@ -330,69 +452,117 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w | |||
| // c_t := f_t .* c_{t-1} + i_t .* g_t | |||
| // h_t := o_t .* tanh[c_t] | |||
| float* output_data = top_blob.row(ti); | |||
| float* cell_ptr = cell_state; | |||
| float* hidden_ptr = hidden_state; | |||
| const float* gates_data_I = gates.row(0); | |||
| const float* gates_data_F = gates.row(1); | |||
| const float* gates_data_O = gates.row(2); | |||
| const float* gates_data_G = gates.row(3); | |||
| int nn_activation = num_output >> 3; | |||
| int remain_activations = num_output & 7; | |||
| for (; nn_activation > 0; nn_activation--) | |||
| float* tmp_hidden_ptr = tmp_hidden_state; | |||
| #if __SSE2__ | |||
| nn_hidden_size = hidden_size >> 2; | |||
| remain_hidden_size_start = nn_hidden_size << 2; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int qq = 0; qq < nn_hidden_size; qq++) | |||
| { | |||
| __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I)); | |||
| __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F)); | |||
| __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O)); | |||
| __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G)); | |||
| __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G)); | |||
| __m256 H = _mm256_mul_ps(O, tanh_avx(cell2)); | |||
| _mm256_storeu_ps(cell_ptr, cell2); | |||
| _mm256_storeu_ps(hidden_ptr, H); | |||
| _mm256_storeu_ps(output_data, H); | |||
| cell_ptr += 8; | |||
| output_data += 8; | |||
| hidden_ptr += 8; | |||
| gates_data_I += 8; | |||
| gates_data_F += 8; | |||
| gates_data_O += 8; | |||
| gates_data_G += 8; | |||
| int q = qq * 4; | |||
| const float* gates_data = gates.row(q); | |||
| __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); | |||
| __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); | |||
| __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); | |||
| __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); | |||
| _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); | |||
| __m128 _I = sigmoid_sse(_IFOG_4x4_0); | |||
| __m128 _F = sigmoid_sse(_IFOG_4x4_1); | |||
| __m128 _O = sigmoid_sse(_IFOG_4x4_2); | |||
| __m128 _G = tanh_sse(_IFOG_4x4_3); | |||
| __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_I, _G)); | |||
| __m128 _H = _mm_mul_ps(_O, tanh_sse(_cell2)); | |||
| _mm_storeu_ps(cell_ptr + q, _cell2); | |||
| if (num_output == hidden_size) | |||
| { | |||
| _mm_storeu_ps(hidden_ptr + q, _H); | |||
| _mm_storeu_ps(output_data + q, _H); | |||
| } | |||
| else | |||
| { | |||
| _mm_storeu_ps(tmp_hidden_ptr + q, _H); | |||
| } | |||
| } | |||
| for (; remain_activations > 0; remain_activations--) | |||
| #else // __SSE2__ | |||
| remain_hidden_size_start = 0; | |||
| #endif // __SSE2__ | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_hidden_size_start; q < hidden_size; q++) | |||
| { | |||
| float I = *gates_data_I; | |||
| float F = *gates_data_F; | |||
| float O = *gates_data_O; | |||
| float G = *gates_data_G; | |||
| const float* gates_data = gates.row(q); | |||
| float I = gates_data[0]; | |||
| float F = gates_data[1]; | |||
| float O = gates_data[2]; | |||
| float G = gates_data[3]; | |||
| I = 1.f / (1.f + exp(-I)); | |||
| F = 1.f / (1.f + exp(-F)); | |||
| O = 1.f / (1.f + exp(-O)); | |||
| G = tanh(G); | |||
| float cell2 = F * *cell_ptr + I * G; | |||
| float cell2 = F * cell_ptr[q] + I * G; | |||
| float H = O * tanh(cell2); | |||
| *cell_ptr = cell2; | |||
| *hidden_ptr = H; | |||
| *output_data = H; | |||
| cell_ptr++; | |||
| output_data++; | |||
| hidden_ptr++; | |||
| gates_data_I++; | |||
| gates_data_F++; | |||
| gates_data_O++; | |||
| gates_data_G++; | |||
| cell_ptr[q] = cell2; | |||
| if (num_output == hidden_size) | |||
| { | |||
| hidden_ptr[q] = H; | |||
| output_data[q] = H; | |||
| } | |||
| else | |||
| { | |||
| tmp_hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| // no cell output here | |||
| if (num_output != hidden_size) | |||
| { | |||
| // int nn_num_output = num_output >> 2; | |||
| // int remain_num_output_start = nn_num_output << 2; | |||
| // #pragma omp parallel for num_threads(opt.num_threads) | |||
| // for (int qq = 0; qq < nn_num_output; qq++) | |||
| // { | |||
| // int q = qq * 4; | |||
| // | |||
| // } | |||
| int remain_num_output_start = 0; | |||
| #pragma omp parallel for num_threads(opt.num_threads) | |||
| for (int q = remain_num_output_start; q < num_output; q++) | |||
| { | |||
| const float* hr = weight_hr.row(q); | |||
| const float* tmp_hidden_ptr = tmp_hidden_state; | |||
| float H = 0; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| H += tmp_hidden_ptr[i] * hr[i]; | |||
| } | |||
| output_data[q] = H; | |||
| hidden_ptr[q] = H; | |||
| } | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| #endif | |||
| int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const | |||
| { | |||
| #if __AVX__ | |||
| int T = bottom_blob.h; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| // initial hidden state | |||
| @@ -400,8 +570,8 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| if (hidden.empty()) | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| // internal cell state | |||
| Mat cell(num_output, 4u, opt.workspace_allocator); | |||
| Mat cell(hidden_size, 4u, opt.workspace_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -413,7 +583,7 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -428,14 +598,14 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| if (top_blob_reverse.empty()) | |||
| return -100; | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| hidden.fill(0.0f); | |||
| cell.fill(0.0f); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -452,14 +622,10 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| } | |||
| return 0; | |||
| #else | |||
| return LSTM::forward(bottom_blob, top_blob, opt); | |||
| #endif | |||
| } | |||
| int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const | |||
| { | |||
| #if __AVX__ | |||
| const Mat& bottom_blob = bottom_blobs[0]; | |||
| int T = bottom_blob.h; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| @@ -479,7 +645,7 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| return -100; | |||
| hidden.fill(0.f); | |||
| cell.create(num_output, num_directions, 4u, hidden_cell_allocator); | |||
| cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); | |||
| if (cell.empty()) | |||
| return -100; | |||
| cell.fill(0.f); | |||
| @@ -493,7 +659,7 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| // Uni directional | |||
| if (direction == 0 || direction == 1) | |||
| { | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); | |||
| int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| @@ -510,15 +676,13 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| Mat hidden0 = hidden.row_range(0, 1); | |||
| Mat cell0 = cell.row_range(0, 1); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); | |||
| int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); | |||
| if (ret0 != 0) | |||
| return ret0; | |||
| Mat hidden1 = hidden.row_range(1, 1); | |||
| Mat cell1 = cell.row_range(1, 1); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); | |||
| int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); | |||
| if (ret1 != 0) | |||
| return ret1; | |||
| @@ -541,9 +705,6 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to | |||
| } | |||
| return 0; | |||
| #else | |||
| return LSTM::forward(bottom_blobs, top_blobs, opt); | |||
| #endif | |||
| } | |||
| } // namespace ncnn | |||
| @@ -31,6 +31,9 @@ public: | |||
| virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const; | |||
| public: | |||
| Mat weight_xc_data_packed; | |||
| Mat bias_c_data_packed; | |||
| Mat weight_hc_data_packed; | |||
| }; | |||
| } // namespace ncnn | |||
| @@ -15,50 +15,64 @@ | |||
| #include "layer/lstm.h" | |||
| #include "testutil.h" | |||
| static int test_lstm(const ncnn::Mat& a, int outch, int direction) | |||
| static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) | |||
| { | |||
| int input_size = a.w; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| if (hidden_size == 0) | |||
| hidden_size = outch; | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, outch); | |||
| pd.set(1, outch * input_size * 4 * num_directions); | |||
| pd.set(1, hidden_size * input_size * 4 * num_directions); | |||
| pd.set(2, direction); | |||
| pd.set(3, hidden_size); | |||
| std::vector<ncnn::Mat> weights(3); | |||
| weights[0] = RandomMat(outch * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(outch * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * outch * 4 * num_directions); | |||
| std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4); | |||
| weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(hidden_size * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); | |||
| if (hidden_size) | |||
| { | |||
| weights[3] = RandomMat(hidden_size * outch * num_directions); | |||
| } | |||
| int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, a); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); | |||
| fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); | |||
| } | |||
| return ret; | |||
| } | |||
| int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) | |||
| int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) | |||
| { | |||
| int input_size = a.w; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| if (hidden_size == 0) | |||
| hidden_size = outch; | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, outch); | |||
| pd.set(1, outch * input_size * 4 * num_directions); | |||
| pd.set(1, hidden_size * input_size * 4 * num_directions); | |||
| pd.set(2, direction); | |||
| pd.set(3, hidden_size); | |||
| std::vector<ncnn::Mat> weights(3); | |||
| weights[0] = RandomMat(outch * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(outch * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * outch * 4 * num_directions); | |||
| std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4); | |||
| weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(hidden_size * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); | |||
| if (hidden_size) | |||
| { | |||
| weights[3] = RandomMat(hidden_size * outch * num_directions); | |||
| } | |||
| // initial hidden state | |||
| ncnn::Mat hidden = RandomMat(outch, num_directions); | |||
| // initial cell state | |||
| ncnn::Mat cell = RandomMat(outch, num_directions); | |||
| ncnn::Mat cell = RandomMat(hidden_size, num_directions); | |||
| std::vector<ncnn::Mat> as(3); | |||
| as[0] = a; | |||
| @@ -68,32 +82,39 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) | |||
| int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 3); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); | |||
| fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); | |||
| } | |||
| return ret; | |||
| } | |||
| int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) | |||
| int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) | |||
| { | |||
| int input_size = a.w; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| if (hidden_size == 0) | |||
| hidden_size = outch; | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, outch); | |||
| pd.set(1, outch * input_size * 4 * num_directions); | |||
| pd.set(1, hidden_size * input_size * 4 * num_directions); | |||
| pd.set(2, direction); | |||
| pd.set(3, hidden_size); | |||
| std::vector<ncnn::Mat> weights(3); | |||
| weights[0] = RandomMat(outch * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(outch * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * outch * 4 * num_directions); | |||
| std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4); | |||
| weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(hidden_size * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); | |||
| if (hidden_size) | |||
| { | |||
| weights[3] = RandomMat(hidden_size * outch * num_directions); | |||
| } | |||
| // initial hidden state | |||
| ncnn::Mat hidden = RandomMat(outch, num_directions); | |||
| // initial cell state | |||
| ncnn::Mat cell = RandomMat(outch, num_directions); | |||
| ncnn::Mat cell = RandomMat(hidden_size, num_directions); | |||
| std::vector<ncnn::Mat> as(3); | |||
| as[0] = a; | |||
| @@ -103,26 +124,33 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi | |||
| int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 1); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); | |||
| fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); | |||
| } | |||
| return ret; | |||
| } | |||
| int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) | |||
| int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) | |||
| { | |||
| int input_size = a.w; | |||
| int num_directions = direction == 2 ? 2 : 1; | |||
| if (hidden_size == 0) | |||
| hidden_size = outch; | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, outch); | |||
| pd.set(1, outch * input_size * 4 * num_directions); | |||
| pd.set(1, hidden_size * input_size * 4 * num_directions); | |||
| pd.set(2, direction); | |||
| pd.set(3, hidden_size); | |||
| std::vector<ncnn::Mat> weights(3); | |||
| weights[0] = RandomMat(outch * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(outch * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * outch * 4 * num_directions); | |||
| std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4); | |||
| weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); | |||
| weights[1] = RandomMat(hidden_size * 4 * num_directions); | |||
| weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); | |||
| if (hidden_size) | |||
| { | |||
| weights[3] = RandomMat(hidden_size * outch * num_directions); | |||
| } | |||
| std::vector<ncnn::Mat> as(1); | |||
| as[0] = a; | |||
| @@ -130,7 +158,7 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct | |||
| int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 3); | |||
| if (ret != 0) | |||
| { | |||
| fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); | |||
| fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); | |||
| } | |||
| return ret; | |||
| @@ -147,7 +175,7 @@ static int test_lstm_0() | |||
| || test_lstm(RandomMat(5, 16), 16, 2) | |||
| || test_lstm(RandomMat(3, 16), 8, 2) | |||
| || test_lstm(RandomMat(8, 16), 16, 2) | |||
| || test_lstm(RandomMat(2, 5), 17, 2); | |||
| || test_lstm(RandomMat(2, 5), 17, 2, 15); | |||
| } | |||
| static int test_lstm_1() | |||
| @@ -160,7 +188,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2) | |||
| || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2) | |||
| || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2, 33) | |||
| || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1) | |||
| @@ -168,7 +196,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1, 33) | |||
| || test_lstm_layer_with_hidden(RandomMat(4, 2), 1, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 0) | |||
| @@ -176,7 +204,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0) | |||
| || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0, 15) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2) | |||
| @@ -185,7 +213,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2, 33) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1) | |||
| @@ -193,7 +221,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1, 33) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0) | |||
| @@ -201,7 +229,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0) | |||
| || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0, 15) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2) | |||
| @@ -210,7 +238,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2, 33) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1) | |||
| @@ -218,7 +246,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1, 33) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0) | |||
| @@ -226,7 +254,7 @@ static int test_lstm_1() | |||
| || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0) | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0); | |||
| || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0, 15); | |||
| } | |||
| static int test_lstm_2() | |||
| @@ -240,7 +268,7 @@ static int test_lstm_2() | |||
| || test_lstm(RandomMat(5, 16), 16, 0) | |||
| || test_lstm(RandomMat(3, 16), 8, 0) | |||
| || test_lstm(RandomMat(8, 16), 16, 0) | |||
| || test_lstm(RandomMat(2, 5), 17, 0); | |||
| || test_lstm(RandomMat(2, 5), 17, 0, 15); | |||
| } | |||
| static int test_lstm_3() | |||
| { | |||
| @@ -253,7 +281,7 @@ static int test_lstm_3() | |||
| || test_lstm(RandomMat(5, 16), 16, 1) | |||
| || test_lstm(RandomMat(3, 16), 8, 1) | |||
| || test_lstm(RandomMat(8, 16), 16, 1) | |||
| || test_lstm(RandomMat(2, 5), 17, 1); | |||
| || test_lstm(RandomMat(2, 5), 17, 1, 15); | |||
| } | |||
| int main() | |||
| @@ -33,9 +33,9 @@ public: | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| { | |||
| // mod.dump(true, true, true); | |||
| // graph->dump(); | |||
| // mod.dump(true, true, true); | |||
| // | |||
| // graph->dump(); | |||
| const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); | |||
| @@ -49,12 +49,13 @@ public: | |||
| op->params["pnnx_rnn_output_swapped"] = 1; | |||
| } | |||
| // for (auto aa : lstm->schema().arguments()) | |||
| // { | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| // for (auto aa : lstm->schema().arguments()) | |||
| // { | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); | |||
| const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor(); | |||
| op->params["input_size"] = weight_ih_l0.size(1); | |||
| op->params["hidden_size"] = weight_ih_l0.size(0) / 4; | |||
| @@ -62,10 +63,12 @@ public: | |||
| op->params["bias"] = lstm->namedInput("has_biases"); | |||
| op->params["batch_first"] = lstm->namedInput("batch_first"); | |||
| op->params["bidirectional"] = lstm->namedInput("bidirectional"); | |||
| op->params["proj_size"] = weight_ih_l0.size(0) / 4 == weight_hh_l0.size(1) ? 0 : weight_hh_l0.size(1); | |||
| const int num_layers = op->params["num_layers"].i; | |||
| const bool bias = op->params["bias"].b; | |||
| const bool bidirectional = op->params["bidirectional"].b; | |||
| const int proj_size = op->params["proj_size"].i; | |||
| for (int k = 0; k < num_layers; k++) | |||
| { | |||
| @@ -84,6 +87,13 @@ public: | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); | |||
| } | |||
| if (proj_size > 0) | |||
| { | |||
| std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k); | |||
| op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor(); | |||
| } | |||
| if (bidirectional) | |||
| { | |||
| std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; | |||
| @@ -100,6 +110,13 @@ public: | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); | |||
| } | |||
| if (proj_size > 0) | |||
| { | |||
| std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -42,6 +42,7 @@ void unroll_rnn_op(Graph& graph) | |||
| bool has_output_hidden = op->outputs.size() >= 2; | |||
| bool has_output_cell = op->outputs.size() == 3; | |||
| const int hidden_size = op->params["hidden_size"].i; | |||
| const int proj_size = (op->type == "nn.LSTM") ? op->params["proj_size"].i : 0; | |||
| bool has_bias = op->params["bias"].b; | |||
| bool is_bidirectional = op->params["bidirectional"].b; | |||
| @@ -116,7 +117,14 @@ void unroll_rnn_op(Graph& graph) | |||
| } | |||
| else | |||
| { | |||
| op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size; | |||
| if (proj_size) | |||
| { | |||
| op1->params["input_size"] = is_bidirectional ? proj_size * 2 : proj_size; | |||
| } | |||
| else | |||
| { | |||
| op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size; | |||
| } | |||
| op1->inputs.push_back(unrolled_ops[j - 1]->outputs[0]); | |||
| op1->inputs[0]->consumers.push_back(op1); | |||
| @@ -171,6 +179,11 @@ void unroll_rnn_op(Graph& graph) | |||
| op1->attrs["bias_ih_l0"] = op->attrs["bias_ih_l" + std::to_string(j)]; | |||
| } | |||
| if (proj_size) | |||
| { | |||
| op1->attrs["weight_hr_l0"] = op->attrs["weight_hr_l" + std::to_string(j)]; | |||
| } | |||
| if (is_bidirectional) | |||
| { | |||
| op1->attrs["weight_hh_l0_reverse"] = op->attrs["weight_hh_l" + std::to_string(j) + "_reverse"]; | |||
| @@ -181,6 +194,11 @@ void unroll_rnn_op(Graph& graph) | |||
| op1->attrs["bias_hh_l0_reverse"] = op->attrs["bias_hh_l" + std::to_string(j) + "_reverse"]; | |||
| op1->attrs["bias_ih_l0_reverse"] = op->attrs["bias_ih_l" + std::to_string(j) + "_reverse"]; | |||
| } | |||
| if (proj_size) | |||
| { | |||
| op1->attrs["weight_hr_l0_reverse"] = op->attrs["weight_hr_l" + std::to_string(j) + "_reverse"]; | |||
| } | |||
| } | |||
| unrolled_ops[j] = op1; | |||
| @@ -27,7 +27,7 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 3 4 | |||
| pnnx.Input input 0 1 input | |||
| nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse | |||
| nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse | |||
| pnnx.Output output 3 0 out out_hidden out_cell | |||
| )PNNXIR"; | |||
| } | |||
| @@ -46,14 +46,19 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| { | |||
| const bool bidirectional = captured_params.at("bidirectional").b; | |||
| const int num_directions = bidirectional ? 2 : 1; | |||
| const int num_output = captured_params.at("hidden_size").i; | |||
| const int hidden_size = captured_params.at("hidden_size").i; | |||
| const int input_size = captured_params.at("input_size").i; | |||
| int weight_data_size = num_directions * num_output * input_size * 4; | |||
| int proj_size = captured_params.at("proj_size").i; | |||
| if (proj_size == 0) | |||
| proj_size = hidden_size; | |||
| op->params["0"] = num_output; | |||
| int weight_data_size = num_directions * hidden_size * input_size * 4; | |||
| op->params["0"] = proj_size; | |||
| op->params["1"] = weight_data_size; | |||
| op->params["2"] = bidirectional ? 2 : 0; | |||
| op->params["3"] = hidden_size; | |||
| op->attrs["0"] = Attribute(); | |||
| op->attrs["0"].data = {0, 0, 0, 0}; | |||
| @@ -62,7 +67,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| { | |||
| std::vector<float> new_weight_ih; | |||
| { | |||
| const int weight_data_size_g = num_output * input_size; | |||
| const int weight_data_size_g = hidden_size * input_size; | |||
| const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0").data.data(); | |||
| const float* iptr = weight_ih; | |||
| @@ -70,7 +75,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* gptr = weight_ih + weight_data_size_g * 2; | |||
| const float* optr = weight_ih + weight_data_size_g * 3; | |||
| new_weight_ih.resize(4 * num_output * input_size); | |||
| new_weight_ih.resize(4 * hidden_size * input_size); | |||
| float* weight = (float*)new_weight_ih.data(); | |||
| float* w_iptr = weight; | |||
| float* w_fptr = weight + weight_data_size_g; | |||
| @@ -86,7 +91,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| { | |||
| std::vector<float> new_weight_ih_reverse; | |||
| { | |||
| const int weight_data_size_g = num_output * input_size; | |||
| const int weight_data_size_g = hidden_size * input_size; | |||
| const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0_reverse").data.data(); | |||
| const float* iptr = weight_ih; | |||
| @@ -94,7 +99,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* gptr = weight_ih + weight_data_size_g * 2; | |||
| const float* optr = weight_ih + weight_data_size_g * 3; | |||
| new_weight_ih_reverse.resize(4 * num_output * input_size); | |||
| new_weight_ih_reverse.resize(4 * hidden_size * input_size); | |||
| float* weight = (float*)new_weight_ih_reverse.data(); | |||
| float* w_iptr = weight; | |||
| float* w_fptr = weight + weight_data_size_g; | |||
| @@ -105,11 +110,11 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); | |||
| memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); | |||
| } | |||
| op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih) + Attribute({4, num_output, input_size}, new_weight_ih_reverse); | |||
| op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih) + Attribute({4, hidden_size, input_size}, new_weight_ih_reverse); | |||
| } | |||
| else | |||
| { | |||
| op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih); | |||
| op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih); | |||
| } | |||
| } | |||
| @@ -124,33 +129,33 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); | |||
| const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); | |||
| const float* bias_ih_iptr = bias_ih; | |||
| const float* bias_ih_fptr = bias_ih + num_output; | |||
| const float* bias_ih_gptr = bias_ih + num_output * 2; | |||
| const float* bias_ih_optr = bias_ih + num_output * 3; | |||
| const float* bias_ih_fptr = bias_ih + hidden_size; | |||
| const float* bias_ih_gptr = bias_ih + hidden_size * 2; | |||
| const float* bias_ih_optr = bias_ih + hidden_size * 3; | |||
| const float* bias_hh_iptr = bias_hh; | |||
| const float* bias_hh_fptr = bias_hh + num_output; | |||
| const float* bias_hh_gptr = bias_hh + num_output * 2; | |||
| const float* bias_hh_optr = bias_hh + num_output * 3; | |||
| const float* bias_hh_fptr = bias_hh + hidden_size; | |||
| const float* bias_hh_gptr = bias_hh + hidden_size * 2; | |||
| const float* bias_hh_optr = bias_hh + hidden_size * 3; | |||
| new_bias.resize(4 * num_output); | |||
| new_bias.resize(4 * hidden_size); | |||
| float* bias = (float*)new_bias.data(); | |||
| float* b_iptr = bias; | |||
| float* b_fptr = bias + num_output; | |||
| float* b_optr = bias + num_output * 2; | |||
| float* b_gptr = bias + num_output * 3; | |||
| for (int i = 0; i < num_output; i++) | |||
| float* b_fptr = bias + hidden_size; | |||
| float* b_optr = bias + hidden_size * 2; | |||
| float* b_gptr = bias + hidden_size * 3; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; | |||
| } | |||
| @@ -163,63 +168,63 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); | |||
| const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); | |||
| const float* bias_ih_iptr = bias_ih; | |||
| const float* bias_ih_fptr = bias_ih + num_output; | |||
| const float* bias_ih_gptr = bias_ih + num_output * 2; | |||
| const float* bias_ih_optr = bias_ih + num_output * 3; | |||
| const float* bias_ih_fptr = bias_ih + hidden_size; | |||
| const float* bias_ih_gptr = bias_ih + hidden_size * 2; | |||
| const float* bias_ih_optr = bias_ih + hidden_size * 3; | |||
| const float* bias_hh_iptr = bias_hh; | |||
| const float* bias_hh_fptr = bias_hh + num_output; | |||
| const float* bias_hh_gptr = bias_hh + num_output * 2; | |||
| const float* bias_hh_optr = bias_hh + num_output * 3; | |||
| const float* bias_hh_fptr = bias_hh + hidden_size; | |||
| const float* bias_hh_gptr = bias_hh + hidden_size * 2; | |||
| const float* bias_hh_optr = bias_hh + hidden_size * 3; | |||
| new_bias_reverse.resize(4 * num_output); | |||
| new_bias_reverse.resize(4 * hidden_size); | |||
| float* bias = (float*)new_bias_reverse.data(); | |||
| float* b_iptr = bias; | |||
| float* b_fptr = bias + num_output; | |||
| float* b_optr = bias + num_output * 2; | |||
| float* b_gptr = bias + num_output * 3; | |||
| for (int i = 0; i < num_output; i++) | |||
| float* b_fptr = bias + hidden_size; | |||
| float* b_optr = bias + hidden_size * 2; | |||
| float* b_gptr = bias + hidden_size * 3; | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; | |||
| } | |||
| for (int i = 0; i < num_output; i++) | |||
| for (int i = 0; i < hidden_size; i++) | |||
| { | |||
| b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; | |||
| } | |||
| } | |||
| op->attrs["3"] = Attribute({4, num_output}, new_bias) + Attribute({4, num_output}, new_bias_reverse); | |||
| op->attrs["3"] = Attribute({4, hidden_size}, new_bias) + Attribute({4, hidden_size}, new_bias_reverse); | |||
| } | |||
| else | |||
| { | |||
| op->attrs["3"] = Attribute({4, num_output}, new_bias); | |||
| op->attrs["3"] = Attribute({4, hidden_size}, new_bias); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| std::vector<float> bias(4 * num_output, 0.f); | |||
| std::vector<float> bias(4 * hidden_size, 0.f); | |||
| if (bidirectional) | |||
| op->attrs["3"] = Attribute({4, num_output}, bias) + Attribute({4, num_output}, bias); | |||
| op->attrs["3"] = Attribute({4, hidden_size}, bias) + Attribute({4, hidden_size}, bias); | |||
| else | |||
| op->attrs["3"] = Attribute({4, num_output}, bias); | |||
| op->attrs["3"] = Attribute({4, hidden_size}, bias); | |||
| } | |||
| op->attrs["4"] = Attribute(); | |||
| op->attrs["4"].data = {0, 0, 0, 0}; | |||
| // reorder IFGO-hidden-hidden to IFOG-hidden-hidden | |||
| // reorder IFGO-hidden-proj to IFOG-hidden-proj | |||
| { | |||
| std::vector<float> new_weight_hh; | |||
| { | |||
| const int weight_data_size_g = num_output * num_output; | |||
| const int weight_data_size_g = hidden_size * proj_size; | |||
| const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0").data.data(); | |||
| const float* iptr = weight_hh; | |||
| @@ -227,7 +232,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* gptr = weight_hh + weight_data_size_g * 2; | |||
| const float* optr = weight_hh + weight_data_size_g * 3; | |||
| new_weight_hh.resize(4 * num_output * num_output); | |||
| new_weight_hh.resize(4 * hidden_size * proj_size); | |||
| float* weight = (float*)new_weight_hh.data(); | |||
| float* w_iptr = weight; | |||
| float* w_fptr = weight + weight_data_size_g; | |||
| @@ -243,7 +248,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| { | |||
| std::vector<float> new_weight_hh_reverse; | |||
| { | |||
| const int weight_data_size_g = num_output * num_output; | |||
| const int weight_data_size_g = hidden_size * proj_size; | |||
| const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0_reverse").data.data(); | |||
| const float* iptr = weight_hh; | |||
| @@ -251,7 +256,7 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| const float* gptr = weight_hh + weight_data_size_g * 2; | |||
| const float* optr = weight_hh + weight_data_size_g * 3; | |||
| new_weight_hh_reverse.resize(4 * num_output * num_output); | |||
| new_weight_hh_reverse.resize(4 * hidden_size * proj_size); | |||
| float* weight = (float*)new_weight_hh_reverse.data(); | |||
| float* w_iptr = weight; | |||
| float* w_fptr = weight + weight_data_size_g; | |||
| @@ -262,11 +267,26 @@ pnnx.Output output 3 0 out out_hidden out_cell | |||
| memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); | |||
| memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); | |||
| } | |||
| op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh) + Attribute({4, num_output, num_output}, new_weight_hh_reverse); | |||
| op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh) + Attribute({4, hidden_size, proj_size}, new_weight_hh_reverse); | |||
| } | |||
| else | |||
| { | |||
| op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh); | |||
| } | |||
| } | |||
| if (proj_size != hidden_size) | |||
| { | |||
| op->attrs["6"] = Attribute(); | |||
| op->attrs["6"].data = {0, 0, 0, 0}; | |||
| if (bidirectional) | |||
| { | |||
| op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0") + captured_attrs.at("op_0.weight_hr_l0_reverse"); | |||
| } | |||
| else | |||
| { | |||
| op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh); | |||
| op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0"); | |||
| } | |||
| } | |||
| } | |||
| @@ -284,7 +304,7 @@ public: | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Input in_hidden 0 1 in_hidden | |||
| pnnx.Input in_hidden 0 1 in_cell | |||
| nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse | |||
| nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse | |||
| pnnx.Output output 3 0 out out_hidden out_cell | |||
| )PNNXIR"; | |||
| } | |||
| @@ -300,7 +320,7 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse | |||
| nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| @@ -318,7 +338,7 @@ public: | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Input in_hidden 0 1 in_hidden | |||
| pnnx.Input in_hidden 0 1 in_cell | |||
| nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse | |||
| nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| @@ -22,15 +22,15 @@ class Model(nn.Module): | |||
| self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) | |||
| self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) | |||
| self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) | |||
| self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) | |||
| self.lstm_0_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) | |||
| self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) | |||
| self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) | |||
| self.lstm_0_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) | |||
| self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) | |||
| self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) | |||
| self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) | |||
| self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) | |||
| self.lstm_1_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) | |||
| self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) | |||
| self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) | |||
| self.lstm_1_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) | |||
| def forward(self, x, y): | |||
| x = x.permute(1, 0, 2) | |||
| @@ -38,14 +38,14 @@ class Model(nn.Module): | |||
| x0, _ = self.lstm_0_0(x) | |||
| x1, _ = self.lstm_0_1(x0) | |||
| x2, (h0, c0) = self.lstm_0_2(x1) | |||
| x3, (h1, c1) = self.lstm_0_3(x1, (h0, c0)) | |||
| x4, _ = self.lstm_0_4(x1, (h1, c1)) | |||
| x3, (h1, c1) = self.lstm_0_3(x2, (h0, c0)) | |||
| x4, _ = self.lstm_0_4(x3, (h1, c1)) | |||
| y0, _ = self.lstm_1_0(y) | |||
| y1, _ = self.lstm_1_1(y0) | |||
| y2, (h2, c2) = self.lstm_1_2(y1) | |||
| y3, (h3, c3) = self.lstm_1_3(y1, (h2, c2)) | |||
| y4, _ = self.lstm_1_4(y1, (h3, c3)) | |||
| y3, (h3, c3) = self.lstm_1_3(y2, (h2, c2)) | |||
| y4, _ = self.lstm_1_4(y3, (h3, c3)) | |||
| x2 = x2.permute(1, 0, 2) | |||
| x3 = x3.permute(1, 0, 2) | |||
| @@ -22,24 +22,24 @@ class Model(nn.Module): | |||
| self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) | |||
| self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) | |||
| self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) | |||
| self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) | |||
| self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) | |||
| self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) | |||
| self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) | |||
| self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) | |||
| self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) | |||
| self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) | |||
| self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) | |||
| self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) | |||
| def forward(self, x, y): | |||
| x0, (h0, c0) = self.lstm_0_0(x) | |||
| x1, (h1, c1) = self.lstm_0_1(x0) | |||
| x2, (h2, c2) = self.lstm_0_2(x1) | |||
| x3, (h3, c3) = self.lstm_0_3(x1, (h2, c2)) | |||
| x3, (h3, c3) = self.lstm_0_3(x2, (h2, c2)) | |||
| y0, (h4, c4) = self.lstm_1_0(y) | |||
| y1, (h5, c5) = self.lstm_1_1(y0) | |||
| y2, (h6, c6) = self.lstm_1_2(y1) | |||
| y3, (h7, c7) = self.lstm_1_3(y1, (h6, c6)) | |||
| y3, (h7, c7) = self.lstm_1_3(y2, (h6, c6)) | |||
| return x2, x3, h0, h1, h2, h3, c0, c1, c2, c3, y2, y3, h4, h5, h6, h7, c4, c5, c6, c7 | |||
| def test(): | |||