From 77eda4c19ff99353faca9ef7509af0255773a8dd Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 14 Oct 2022 10:54:03 +0800 Subject: [PATCH] implement lstm proj_size (#4263) --- .github/workflows/test-coverage.yml | 2 +- docs/developer-guide/operators.md | 10 +- src/layer/arm/lstm_arm.cpp | 256 +++++-- src/layer/arm/lstm_arm_asimdhp.cpp | 285 ++++--- src/layer/lstm.cpp | 97 ++- src/layer/lstm.h | 2 + src/layer/x86/lstm_x86.cpp | 763 +++++++++++-------- src/layer/x86/lstm_x86.h | 3 + tests/test_lstm.cpp | 112 ++- tools/pnnx/src/pass_level1/nn_LSTM.cpp | 31 +- tools/pnnx/src/pass_level5/unroll_rnn_op.cpp | 20 +- tools/pnnx/src/pass_ncnn/nn_LSTM.cpp | 126 +-- tools/pnnx/tests/ncnn/test_nn_LSTM.py | 20 +- tools/pnnx/tests/test_nn_LSTM.py | 12 +- 14 files changed, 1133 insertions(+), 606 deletions(-) diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index f84eeb252..bc6b0097b 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -227,7 +227,7 @@ jobs: uses: actions/cache@v3 with: path: lavapipe-install - key: lavapipe-linux-install-20211127-2 + key: lavapipe-linux-install-20211127-3 - name: checkout-lavapipe if: steps.cache-lavapipe.outputs.cache-hit != 'true' uses: actions/checkout@v3 diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 5366da1e1..1cb387ceb 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1026,15 +1026,17 @@ y0, hidden y1, cell y2 = lstm(x0, hidden x1, cell x2) | param id | name | type | default | description | | --------- | ------------- | ----- | --------- | ----------------- | -| 0 | num_output | int | 0 | hidden size of output | +| 0 | num_output | int | 0 | output size of output | | 1 | weight_data_size| int | 0 | total size of IFOG weight matrix | | 2 | direction | int | 0 | 0=forward, 1=reverse, 2=bidirectional | +| 3 | hidden_size | int | num_output| hidden size | | weight | type | shape | | ------------- | ----- | --------------------- | -| weight_xc_data| float/fp16/int8 | [input_size, num_output * 4, num_directions] | -| bias_c_data | float/fp16/int8 | [num_output, 4, num_directions] | -| weight_hc_data| float/fp16/int8 | [num_output, num_output * 4, num_directions] | +| weight_xc_data| float/fp16/int8 | [input_size, hidden_size * 4, num_directions] | +| bias_c_data | float/fp16/int8 | [hidden_size, 4, num_directions] | +| weight_hc_data| float/fp16/int8 | [num_output, hidden_size * 4, num_directions] | +| weight_hr_data| float/fp16/int8 | [hidden_size, num_output, num_directions] | Direction flag: - 0 = forward only diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 440c7bc8c..075da57af 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -58,11 +58,11 @@ int LSTM_arm::create_pipeline(const Option& opt) // pack IFOG int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output / 4; + int size = weight_data_size / num_directions / hidden_size / 4; - weight_xc_data_packed.create(size, num_output, num_directions, 16u, 4); - bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4); - weight_hc_data_packed.create(num_output, num_output, num_directions, 16u, 4); + weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4); #pragma omp parallel for num_threads(opt.num_threads) for (int dr = 0; dr < num_directions; dr++) @@ -82,7 +82,7 @@ int LSTM_arm::create_pipeline(const Option& opt) float* bias_c_IFOG = bias_c_data_packed_dr.row(0); - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { bias_c_IFOG[0] = bias_c_I[q]; bias_c_IFOG[1] = bias_c_F[q]; @@ -91,15 +91,15 @@ int LSTM_arm::create_pipeline(const Option& opt) bias_c_IFOG += 4; - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); @@ -126,21 +126,37 @@ int LSTM_arm::create_pipeline(const Option& opt) } } + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + } + return 0; } -static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) +static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; int num_output = top_blob.w; + int hidden_size = cell_state.w; - // 4 x num_output - Mat gates(4, num_output, 4u, opt.workspace_allocator); + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); if (gates.empty()) return -100; + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + // unroll for (int t = 0; t < T; t++) { @@ -155,7 +171,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w const float* x = bottom_blob.row(ti); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { const float* bias_c_IFOG = (const float*)bias_c + q * 4; @@ -291,14 +307,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; - int remain_num_output_start = 0; + int remain_hidden_size_start = 0; #if __ARM_NEON - int nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; + int nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 4; @@ -315,12 +332,20 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); vst1q_f32(cell_ptr + q, _cell2); - vst1q_f32(hidden_ptr + q, _H); - vst1q_f32(output_data + q, _H); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _H); + vst1q_f32(output_data + q, _H); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _H); + } } #endif // __ARM_NEON #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { const float* gates_data = gates.row(q); @@ -338,8 +363,43 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float H = O * tanh(cell2); cell_ptr[q] = cell2; - hidden_ptr[q] = H; - output_data[q] = H; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = H; + } } } @@ -375,7 +435,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) return -100; hidden.fill(0.f); - Mat cell(num_output, 4u, opt.workspace_allocator); + Mat cell(hidden_size, 4u, opt.workspace_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -387,7 +447,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -402,14 +462,14 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret0 != 0) return ret0; hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); if (ret1 != 0) return ret1; @@ -466,7 +526,7 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to return -100; hidden.fill(0.f); - cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -480,7 +540,7 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -497,13 +557,13 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); if (ret0 != 0) return ret0; Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); if (ret1 != 0) return ret1; @@ -529,18 +589,27 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to } #if NCNN_BF16 -static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) +static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; int num_output = top_blob.w; + int hidden_size = cell_state.w; - // 4 x num_output - Mat gates(4, num_output, 4u, opt.workspace_allocator); + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); if (gates.empty()) return -100; + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + // unroll for (int t = 0; t < T; t++) { @@ -555,7 +624,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const const unsigned short* x = bottom_blob.row(ti); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4; @@ -693,14 +762,15 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; - int remain_num_output_start = 0; + int remain_hidden_size_start = 0; #if __ARM_NEON - int nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; + int nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 4; @@ -717,12 +787,20 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); vst1q_f32(cell_ptr + q, _cell2); - vst1q_f32(hidden_ptr + q, _H); - vst1_u16(output_data + q, bfloat2float(_H)); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _H); + vst1_u16(output_data + q, bfloat2float(_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _H); + } } #endif // __ARM_NEON #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { const float* gates_data = gates.row(q); @@ -740,8 +818,43 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float H = O * tanh(cell2); cell_ptr[q] = cell2; - hidden_ptr[q] = H; - output_data[q] = float32_to_bfloat16(H); + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = float32_to_bfloat16(H); + } } } @@ -752,11 +865,11 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) { // pack IFOG int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output / 4; + int size = weight_data_size / num_directions / hidden_size / 4; - weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4); - bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4); + weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); #pragma omp parallel for num_threads(opt.num_threads) for (int dr = 0; dr < num_directions; dr++) @@ -776,7 +889,7 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) unsigned short* bias_c_IFOG = bias_c_data_packed_dr.row(0); - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { bias_c_IFOG[0] = float32_to_bfloat16(bias_c_I[q]); bias_c_IFOG[1] = float32_to_bfloat16(bias_c_F[q]); @@ -785,15 +898,15 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) bias_c_IFOG += 4; - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); unsigned short* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); unsigned short* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); @@ -820,6 +933,13 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt) } } + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + } + return 0; } @@ -835,7 +955,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& return -100; hidden.fill(0.f); - Mat cell(num_output, 4u, opt.workspace_allocator); + Mat cell(hidden_size, 4u, opt.workspace_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -847,7 +967,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -862,14 +982,14 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret0 != 0) return ret0; hidden.fill(0.f); cell.fill(0.f); - int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); + int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); if (ret1 != 0) return ret1; @@ -911,7 +1031,7 @@ int LSTM_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(ti); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; @@ -141,11 +150,12 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; - int nn_num_output = num_output >> 2; - int remain_num_output_start = nn_num_output << 2; + int nn_hidden_size = hidden_size >> 2; + int remain_hidden_size_start = nn_hidden_size << 2; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 4; @@ -162,11 +172,19 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); vst1q_f32(cell_ptr + q, _cell2); - vst1q_f32(hidden_ptr + q, _H); - vst1_f16(output_data + q, vcvt_f16_f32(_H)); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _H); + vst1_f16(output_data + q, vcvt_f16_f32(_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _H); + } } #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { const float* gates_data = gates.row(q); @@ -184,26 +202,70 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const float H = O * tanh(cell2); cell_ptr[q] = cell2; - hidden_ptr[q] = H; - output_data[q] = (__fp16)(H); + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } } } return 0; } -static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) +static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; int num_output = top_blob.w; + int hidden_size = cell_state.w; - // 4 x num_output - Mat gates(4, num_output, 2u, opt.workspace_allocator); + // 4 x hidden_size + Mat gates(4, hidden_size, 2u, opt.workspace_allocator); if (gates.empty()) return -100; + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + // unroll for (int t = 0; t < T; t++) { @@ -216,10 +278,10 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const int ti = reverse ? T - 1 - t : t; - int nn_num_output = num_output >> 1; - int remain_num_output_start = nn_num_output << 1; + int nn_hidden_size = hidden_size >> 1; + int remain_hidden_size_start = nn_hidden_size << 1; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 2; @@ -319,7 +381,7 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const vst1q_f16(gates_data, _IFOG); } #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; @@ -428,11 +490,12 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; + nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 4; @@ -449,11 +512,19 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2)); vst1q_f32(cell_ptr + q, _cell2); - vst1q_f32(hidden_ptr + q, _H); - vst1_f16(output_data + q, vcvt_f16_f32(_H)); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _H); + vst1_f16(output_data + q, vcvt_f16_f32(_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _H); + } } #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { const __fp16* gates_data = gates.row(q); @@ -471,8 +542,43 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const float H = O * tanh(cell2); cell_ptr[q] = cell2; - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } } } @@ -483,19 +589,19 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) { // pack IFOG int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output / 4; + int size = weight_data_size / num_directions / hidden_size / 4; if (opt.use_fp16_arithmetic) { - weight_xc_data_packed.create(size, num_output / 2 + num_output % 2, num_directions, 16u, 8); - bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output, num_output / 2 + num_output % 2, num_directions, 16u, 8); + weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); } else { - weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4); - bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4); + weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); } #pragma omp parallel for num_threads(opt.num_threads) @@ -519,7 +625,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) if (opt.use_fp16_arithmetic) { int q = 0; - for (; q + 1 < num_output; q += 2) + for (; q + 1 < hidden_size; q += 2) { bias_c_IFOG[0] = (__fp16)bias_c_I[q]; bias_c_IFOG[1] = (__fp16)bias_c_F[q]; @@ -532,23 +638,23 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) bias_c_IFOG += 8; - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); - const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + q + 1); - const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + q + 1); - const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + q + 1); - const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + q + 1); - - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); - const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + q + 1); - const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + q + 1); - const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + q + 1); - const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + q + 1); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); + const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); + const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); + const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); + + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); + const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); + const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); + const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2); __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2); @@ -581,7 +687,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) weight_hc_IFOG += 8; } } - for (; q < num_output; q++) + for (; q < hidden_size; q++) { bias_c_IFOG[0] = (__fp16)bias_c_I[q]; bias_c_IFOG[1] = (__fp16)bias_c_F[q]; @@ -590,15 +696,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) bias_c_IFOG += 4; - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2); __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2); @@ -626,7 +732,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) } else { - for (int q = 0; q < num_output; q++) + for (int q = 0; q < hidden_size; q++) { bias_c_IFOG[0] = (__fp16)bias_c_I[q]; bias_c_IFOG[1] = (__fp16)bias_c_F[q]; @@ -635,15 +741,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) bias_c_IFOG += 4; - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q); __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q); @@ -671,6 +777,13 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) } } + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + } + return 0; } @@ -686,7 +799,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& return -100; hidden.fill(0.f); - Mat cell(num_output, 4u, opt.workspace_allocator); + Mat cell(hidden_size, 4u, opt.workspace_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -698,7 +811,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -713,14 +826,14 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); + int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret0 != 0) return ret0; hidden.fill(0.f); cell.fill(0.f); - int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt); + int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); if (ret1 != 0) return ret1; @@ -762,7 +875,7 @@ int LSTM_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_bl return -100; hidden.fill(0.f); - cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -265,7 +308,7 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -282,13 +325,13 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); if (ret0 != 0) return ret0; Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); if (ret1 != 0) return ret1; diff --git a/src/layer/lstm.h b/src/layer/lstm.h index 78d8366a0..58bd67f98 100644 --- a/src/layer/lstm.h +++ b/src/layer/lstm.h @@ -36,10 +36,12 @@ public: int num_output; int weight_data_size; int direction; // 0=forward 1=reverse 2=bidirectional + int hidden_size; Mat weight_hc_data; Mat weight_xc_data; Mat bias_c_data; + Mat weight_hr_data; }; } // namespace ncnn diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 59124f790..53c8bfe90 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -14,6 +14,13 @@ #include "lstm_x86.h" +#if __SSE2__ +#include +#if __AVX__ +#include +#endif +#endif // __SSE2__ + #include "x86_activation.h" #include "x86_usability.h" @@ -30,23 +37,183 @@ LSTM_x86::LSTM_x86() int LSTM_x86::create_pipeline(const Option& opt) { - (void)(opt); + // pack IFOG + int num_directions = direction == 2 ? 2 : 1; + int size = weight_data_size / num_directions / hidden_size / 4; + +#if __AVX__ + weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8); +#else + weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4); +#endif + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat bias_c = bias_c_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* bias_c_IFOG = bias_c_data_packed_dr.row(0); + + int q = 0; +#if __AVX__ + for (; q + 1 < hidden_size; q += 2) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + bias_c_IFOG[4] = bias_c_I[q + 1]; + bias_c_IFOG[5] = bias_c_F[q + 1]; + bias_c_IFOG[6] = bias_c_O[q + 1]; + bias_c_IFOG[7] = bias_c_G[q + 1]; + + bias_c_IFOG += 8; + + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); + const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); + const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); + const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); + + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); + const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); + const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); + const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); + + float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2); + float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2); + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + weight_xc_IFOG[4] = weight_xc_I_1[i]; + weight_xc_IFOG[5] = weight_xc_F_1[i]; + weight_xc_IFOG[6] = weight_xc_O_1[i]; + weight_xc_IFOG[7] = weight_xc_G_1[i]; + + weight_xc_IFOG += 8; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + weight_hc_IFOG[4] = weight_hc_I_1[i]; + weight_hc_IFOG[5] = weight_hc_F_1[i]; + weight_hc_IFOG[6] = weight_hc_O_1[i]; + weight_hc_IFOG[7] = weight_hc_G_1[i]; + + weight_hc_IFOG += 8; + } + } +#endif // __AVX__ + for (; q < hidden_size; q++) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + + bias_c_IFOG += 4; + + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + +#if __AVX__ + float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2); + float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2); +#else + float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); + float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); +#endif + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + + weight_xc_IFOG += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + + weight_hc_IFOG += 4; + } + } + } + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + } return 0; } -#ifdef __AVX__ -static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt) + +static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; int num_output = top_blob.w; + int hidden_size = cell_state.w; - // 4 x num_output - Mat gates(num_output, 4, 4u, opt.workspace_allocator); + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); if (gates.empty()) return -100; + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + // unroll for (int t = 0; t < T; t++) { @@ -59,267 +226,222 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w int ti = reverse ? T - 1 - t : t; - int nn_num_output = num_output >> 1; - int remain_num_output_start = nn_num_output << 1; +#if __AVX__ + int nn_hidden_size = hidden_size >> 1; + int remain_hidden_size_start = nn_hidden_size << 1; #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) + for (int qq = 0; qq < nn_hidden_size; qq++) { int q = qq * 2; - const float* x = bottom_blob.row(ti); - const float* hidden_ptr_r = hidden_state; - const float* bias_c_I = bias_c.row(0); - const float* bias_c_F = bias_c.row(1); - const float* bias_c_O = bias_c.row(2); - const float* bias_c_G = bias_c.row(3); - - float* gates_data_I = gates.row(0); - float* gates_data_F = gates.row(1); - float* gates_data_O = gates.row(2); - float* gates_data_G = gates.row(3); + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + // gate I F O G - const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q); - const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1)); - const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1)); - const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1)); - const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1)); - - const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q); - const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1)); - const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1)); - const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1)); - const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1)); - - // float I = bias_c_I[q]; - // float F = bias_c_F[q]; - // float O = bias_c_O[q]; - // float G = bias_c_G[q]; - __m256 _sumI_0 = _mm256_setzero_ps(); - __m256 _sumF_0 = _mm256_setzero_ps(); - __m256 _sumO_0 = _mm256_setzero_ps(); - __m256 _sumG_0 = _mm256_setzero_ps(); - __m256 _sumI_1 = _mm256_setzero_ps(); - __m256 _sumF_1 = _mm256_setzero_ps(); - __m256 _sumO_1 = _mm256_setzero_ps(); - __m256 _sumG_1 = _mm256_setzero_ps(); - int nn_num_size = size >> 3; - int remain_size = size & 7; - for (; nn_num_size > 0; nn_num_size--) + const float* weight_xc_IFOG = weight_xc.row(q / 2); + const float* weight_hc_IFOG = weight_hc.row(q / 2); + + __m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + + const float* x = bottom_blob.row(ti); + + int i = 0; + for (; i + 3 < size; i += 4) { - __m256 xi = _mm256_loadu_ps(x); - _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0); - _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0); - _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0); - _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0); - _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1); - _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1); - _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1); - _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1); - x += 8; - weight_xc_I_0 += 8; - weight_xc_F_0 += 8; - weight_xc_O_0 += 8; - weight_xc_G_0 += 8; - weight_xc_I_1 += 8; - weight_xc_F_1 += 8; - weight_xc_O_1 += 8; - weight_xc_G_1 += 8; + __m256 _xi0 = _mm256_broadcast_ss(x); + __m256 _xi1 = _mm256_broadcast_ss(x + 1); + __m256 _xi2 = _mm256_broadcast_ss(x + 2); + __m256 _xi3 = _mm256_broadcast_ss(x + 3); + __m256 _weight_xc_IFOG0 = _mm256_loadu_ps(weight_xc_IFOG); + __m256 _weight_xc_IFOG1 = _mm256_loadu_ps(weight_xc_IFOG + 8); + __m256 _weight_xc_IFOG2 = _mm256_loadu_ps(weight_xc_IFOG + 16); + __m256 _weight_xc_IFOG3 = _mm256_loadu_ps(weight_xc_IFOG + 24); + _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); + _sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); + + x += 4; + weight_xc_IFOG += 32; } - int nn_num_output = num_output >> 3; - int remain_num_output = num_output & 7; - for (; nn_num_output > 0; nn_num_output--) + for (; i < size; i++) { - __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r); - - _sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0); - _sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0); - _sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0); - _sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0); - _sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1); - _sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1); - _sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1); - _sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1); - hidden_ptr_r += 8; - weight_hc_I_0 += 8; - weight_hc_F_0 += 8; - weight_hc_O_0 += 8; - weight_hc_G_0 += 8; - weight_hc_I_1 += 8; - weight_hc_F_1 += 8; - weight_hc_O_1 += 8; - weight_hc_G_1 += 8; + __m256 _xi = _mm256_broadcast_ss(x); + __m256 _weight_xc_IFOG = _mm256_loadu_ps(weight_xc_IFOG); + _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); + + x += 1; + weight_xc_IFOG += 8; } - float sums[8]; - _mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1)); - sums[0] += bias_c_I[q]; - sums[1] += bias_c_F[q]; - sums[2] += bias_c_O[q]; - sums[3] += bias_c_G[q]; - sums[4] += bias_c_I[q + 1]; - sums[5] += bias_c_F[q + 1]; - sums[6] += bias_c_O[q + 1]; - sums[7] += bias_c_G[q + 1]; - - for (; remain_size > 0; remain_size--) + + const float* hidden_ptr = hidden_state; + + i = 0; + for (; i + 3 < num_output; i += 4) { - float xi = *x; - sums[0] += *weight_xc_I_0 * xi; - sums[1] += *weight_xc_F_0 * xi; - sums[2] += *weight_xc_O_0 * xi; - sums[3] += *weight_xc_G_0 * xi; - sums[4] += *weight_xc_I_1 * xi; - sums[5] += *weight_xc_F_1 * xi; - sums[6] += *weight_xc_O_1 * xi; - sums[7] += *weight_xc_G_1 * xi; - x++; - weight_xc_I_0++; - weight_xc_F_0++; - weight_xc_O_0++; - weight_xc_G_0++; - weight_xc_I_1++; - weight_xc_F_1++; - weight_xc_O_1++; - weight_xc_G_1++; + __m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr); + __m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1); + __m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2); + __m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3); + __m256 _weight_hc_IFOG0 = _mm256_loadu_ps(weight_hc_IFOG); + __m256 _weight_hc_IFOG1 = _mm256_loadu_ps(weight_hc_IFOG + 8); + __m256 _weight_hc_IFOG2 = _mm256_loadu_ps(weight_hc_IFOG + 16); + __m256 _weight_hc_IFOG3 = _mm256_loadu_ps(weight_hc_IFOG + 24); + _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); + _sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); + + hidden_ptr += 4; + weight_hc_IFOG += 32; } - - for (; remain_num_output > 0; remain_num_output--) + for (; i < num_output; i++) { - float h_cont = *hidden_ptr_r; - sums[0] += *weight_hc_I_0 * h_cont; - sums[1] += *weight_hc_F_0 * h_cont; - sums[2] += *weight_hc_O_0 * h_cont; - sums[3] += *weight_hc_G_0 * h_cont; - sums[4] += *weight_hc_I_1 * h_cont; - sums[5] += *weight_hc_F_1 * h_cont; - sums[6] += *weight_hc_O_1 * h_cont; - sums[7] += *weight_hc_G_1 * h_cont; - hidden_ptr_r++; - weight_hc_I_0++; - weight_hc_F_0++; - weight_hc_O_0++; - weight_hc_G_0++; - weight_hc_I_1++; - weight_hc_F_1++; - weight_hc_O_1++; - weight_hc_G_1++; + __m256 _h_cont = _mm256_broadcast_ss(hidden_ptr); + __m256 _weight_hc_IFOG = _mm256_loadu_ps(weight_hc_IFOG); + _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); + + hidden_ptr += 1; + weight_hc_IFOG += 8; } - gates_data_I[q] = sums[0]; - gates_data_F[q] = sums[1]; - gates_data_O[q] = sums[2]; - gates_data_G[q] = sums[3]; - gates_data_I[q + 1] = sums[4]; - gates_data_F[q + 1] = sums[5]; - gates_data_O[q + 1] = sums[6]; - gates_data_G[q + 1] = sums[7]; + + float* gates_data = gates.row(q); + + _IFOG = _mm256_add_ps(_IFOG, _sum1); + _sum2 = _mm256_add_ps(_sum2, _sum3); + _IFOG = _mm256_add_ps(_IFOG, _sum2); + + _mm256_storeu_ps(gates_data, _IFOG); } +#else + int nn_hidden_size = 0; + int remain_hidden_size_start = 0; +#endif // __AVX__ + #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { - const float* x = bottom_blob.row(ti); - const float* hidden_ptr_r = hidden_state; - const float* bias_c_I = bias_c.row(0); - const float* bias_c_F = bias_c.row(1); - const float* bias_c_O = bias_c.row(2); - const float* bias_c_G = bias_c.row(3); - - float* gates_data_I = gates.row(0); - float* gates_data_F = gates.row(1); - float* gates_data_O = gates.row(2); - float* gates_data_G = gates.row(3); + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + // gate I F O G - const float* weight_xc_I = weight_xc.row(num_output * 0 + q); - const float* weight_xc_F = weight_xc.row(num_output * 1 + q); - const float* weight_xc_O = weight_xc.row(num_output * 2 + q); - const float* weight_xc_G = weight_xc.row(num_output * 3 + q); - - const float* weight_hc_I = weight_hc.row(num_output * 0 + q); - const float* weight_hc_F = weight_hc.row(num_output * 1 + q); - const float* weight_hc_O = weight_hc.row(num_output * 2 + q); - const float* weight_hc_G = weight_hc.row(num_output * 3 + q); - - // float I = bias_c_I[q]; - // float F = bias_c_F[q]; - // float O = bias_c_O[q]; - // float G = bias_c_G[q]; - __m256 _sumI = _mm256_setzero_ps(); - __m256 _sumF = _mm256_setzero_ps(); - __m256 _sumO = _mm256_setzero_ps(); - __m256 _sumG = _mm256_setzero_ps(); - int nn_num_size = size >> 3; - int remain_size = size & 7; - for (; nn_num_size > 0; nn_num_size--) +#if __AVX__ + const float* weight_xc_IFOG = weight_xc.row(q / 2 + q % 2); + const float* weight_hc_IFOG = weight_hc.row(q / 2 + q % 2); +#else + const float* weight_xc_IFOG = weight_xc.row(q); + const float* weight_hc_IFOG = weight_hc.row(q); +#endif + +#if __SSE2__ + __m128 _IFOG = _mm_loadu_ps(bias_c_IFOG); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); +#else // __SSE2__ + float I = bias_c_IFOG[0]; + float F = bias_c_IFOG[1]; + float O = bias_c_IFOG[2]; + float G = bias_c_IFOG[3]; +#endif // __SSE2__ + + const float* x = bottom_blob.row(ti); + + int i = 0; +#if __SSE2__ + for (; i + 3 < size; i += 4) { - __m256 xi = _mm256_loadu_ps(x); - _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI); - _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF); - _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO); - _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG); - x += 8; - weight_xc_I += 8; - weight_xc_F += 8; - weight_xc_O += 8; - weight_xc_G += 8; + __m128 _xi0 = _mm_load1_ps(x); + __m128 _xi1 = _mm_load1_ps(x + 1); + __m128 _xi2 = _mm_load1_ps(x + 2); + __m128 _xi3 = _mm_load1_ps(x + 3); + __m128 _weight_xc_IFOG0 = _mm_loadu_ps(weight_xc_IFOG); + __m128 _weight_xc_IFOG1 = _mm_loadu_ps(weight_xc_IFOG + 4); + __m128 _weight_xc_IFOG2 = _mm_loadu_ps(weight_xc_IFOG + 8); + __m128 _weight_xc_IFOG3 = _mm_loadu_ps(weight_xc_IFOG + 12); + _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); + _sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); + _sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); + + x += 4; + weight_xc_IFOG += 16; } - int nn_num_output = num_output >> 3; - int remain_num_output = num_output & 7; - for (; nn_num_output > 0; nn_num_output--) +#endif // __SSE2__ + for (; i < size; i++) { - __m256 h_cont = _mm256_loadu_ps(hidden_ptr_r); - - _sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI); - _sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF); - _sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO); - _sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG); - hidden_ptr_r += 8; - weight_hc_I += 8; - weight_hc_F += 8; - weight_hc_O += 8; - weight_hc_G += 8; +#if __SSE2__ + __m128 _xi = _mm_load1_ps(x); + __m128 _weight_xc_IFOG = _mm_loadu_ps(weight_xc_IFOG); + _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); +#else // __SSE2__ + float xi = x[0]; + I += xi * weight_xc_IFOG[0]; + F += xi * weight_xc_IFOG[1]; + O += xi * weight_xc_IFOG[2]; + G += xi * weight_xc_IFOG[3]; +#endif // __SSE2__ + + x += 1; + weight_xc_IFOG += 4; } - float sums[4]; - _mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG)); - sums[0] += bias_c_I[q]; - sums[1] += bias_c_F[q]; - sums[2] += bias_c_O[q]; - sums[3] += bias_c_G[q]; - - for (; remain_size > 0; remain_size--) + + const float* hidden_ptr = hidden_state; + + i = 0; +#if __SSE2__ + for (; i + 3 < num_output; i += 4) { - float xi = *x; - sums[0] += *weight_xc_I * xi; - sums[1] += *weight_xc_F * xi; - sums[2] += *weight_xc_O * xi; - sums[3] += *weight_xc_G * xi; - x++; - weight_xc_I++; - weight_xc_F++; - weight_xc_O++; - weight_xc_G++; + __m128 _h_cont0 = _mm_load1_ps(hidden_ptr); + __m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1); + __m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2); + __m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3); + __m128 _weight_hc_IFOG0 = _mm_loadu_ps(weight_hc_IFOG); + __m128 _weight_hc_IFOG1 = _mm_loadu_ps(weight_hc_IFOG + 4); + __m128 _weight_hc_IFOG2 = _mm_loadu_ps(weight_hc_IFOG + 8); + __m128 _weight_hc_IFOG3 = _mm_loadu_ps(weight_hc_IFOG + 12); + _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); + _sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); + _sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); + + hidden_ptr += 4; + weight_hc_IFOG += 16; } - - for (; remain_num_output > 0; remain_num_output--) +#endif // __SSE2__ + for (; i < num_output; i++) { - float h_cont = *hidden_ptr_r; - sums[0] += *weight_hc_I * h_cont; - sums[1] += *weight_hc_F * h_cont; - sums[2] += *weight_hc_O * h_cont; - sums[3] += *weight_hc_G * h_cont; - hidden_ptr_r++; - weight_hc_I++; - weight_hc_F++; - weight_hc_O++; - weight_hc_G++; +#if __SSE2__ + __m128 _h_cont = _mm_load1_ps(hidden_ptr); + __m128 _weight_hc_IFOG = _mm_loadu_ps(weight_hc_IFOG); + _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); +#else // __SSE2__ + float h_cont = hidden_ptr[0]; + I += h_cont * weight_hc_IFOG[0]; + F += h_cont * weight_hc_IFOG[1]; + O += h_cont * weight_hc_IFOG[2]; + G += h_cont * weight_hc_IFOG[3]; +#endif // __SSE2__ + + hidden_ptr += 1; + weight_hc_IFOG += 4; } - gates_data_I[q] = sums[0]; - gates_data_F[q] = sums[1]; - gates_data_O[q] = sums[2]; - gates_data_G[q] = sums[3]; + + float* gates_data = gates.row(q); + +#if __SSE2__ + _IFOG = _mm_add_ps(_IFOG, _sum1); + _sum2 = _mm_add_ps(_sum2, _sum3); + _IFOG = _mm_add_ps(_IFOG, _sum2); + + _mm_storeu_ps(gates_data, _IFOG); +#else // __SSE2__ + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __SSE2__ } // lstm unit @@ -330,69 +452,117 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w // c_t := f_t .* c_{t-1} + i_t .* g_t // h_t := o_t .* tanh[c_t] float* output_data = top_blob.row(ti); + float* cell_ptr = cell_state; float* hidden_ptr = hidden_state; - const float* gates_data_I = gates.row(0); - const float* gates_data_F = gates.row(1); - const float* gates_data_O = gates.row(2); - const float* gates_data_G = gates.row(3); - int nn_activation = num_output >> 3; - int remain_activations = num_output & 7; - for (; nn_activation > 0; nn_activation--) + float* tmp_hidden_ptr = tmp_hidden_state; + +#if __SSE2__ + nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) { - __m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I)); - __m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F)); - __m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O)); - __m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G)); - __m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G)); - __m256 H = _mm256_mul_ps(O, tanh_avx(cell2)); - _mm256_storeu_ps(cell_ptr, cell2); - _mm256_storeu_ps(hidden_ptr, H); - _mm256_storeu_ps(output_data, H); - cell_ptr += 8; - output_data += 8; - hidden_ptr += 8; - gates_data_I += 8; - gates_data_F += 8; - gates_data_O += 8; - gates_data_G += 8; + int q = qq * 4; + + const float* gates_data = gates.row(q); + + __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); + __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); + __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); + __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); + + _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); + + __m128 _I = sigmoid_sse(_IFOG_4x4_0); + __m128 _F = sigmoid_sse(_IFOG_4x4_1); + __m128 _O = sigmoid_sse(_IFOG_4x4_2); + __m128 _G = tanh_sse(_IFOG_4x4_3); + + __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_I, _G)); + __m128 _H = _mm_mul_ps(_O, tanh_sse(_cell2)); + + _mm_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm_storeu_ps(hidden_ptr + q, _H); + _mm_storeu_ps(output_data + q, _H); + } + else + { + _mm_storeu_ps(tmp_hidden_ptr + q, _H); + } } - for (; remain_activations > 0; remain_activations--) +#else // __SSE2__ + remain_hidden_size_start = 0; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) { - float I = *gates_data_I; - float F = *gates_data_F; - float O = *gates_data_O; - float G = *gates_data_G; + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; I = 1.f / (1.f + exp(-I)); F = 1.f / (1.f + exp(-F)); O = 1.f / (1.f + exp(-O)); G = tanh(G); - float cell2 = F * *cell_ptr + I * G; + + float cell2 = F * cell_ptr[q] + I * G; float H = O * tanh(cell2); - *cell_ptr = cell2; - *hidden_ptr = H; - *output_data = H; - cell_ptr++; - output_data++; - hidden_ptr++; - gates_data_I++; - gates_data_F++; - gates_data_O++; - gates_data_G++; + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_ptr[q] = H; + } } - // no cell output here + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + output_data[q] = H; + hidden_ptr[q] = H; + } + } } return 0; } -#endif int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { -#if __AVX__ int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; // initial hidden state @@ -400,8 +570,8 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (hidden.empty()) return -100; hidden.fill(0.f); - // internal cell state - Mat cell(num_output, 4u, opt.workspace_allocator); + + Mat cell(hidden_size, 4u, opt.workspace_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -413,7 +583,7 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -428,14 +598,14 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret0 != 0) return ret0; hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); if (ret1 != 0) return ret1; @@ -452,14 +622,10 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) } return 0; -#else - return LSTM::forward(bottom_blob, top_blob, opt); -#endif } int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { -#if __AVX__ const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; int num_directions = direction == 2 ? 2 : 1; @@ -479,7 +645,7 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to return -100; hidden.fill(0.f); - cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); if (cell.empty()) return -100; cell.fill(0.f); @@ -493,7 +659,7 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } @@ -510,15 +676,13 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); if (ret0 != 0) return ret0; Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); if (ret1 != 0) return ret1; @@ -541,9 +705,6 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to } return 0; -#else - return LSTM::forward(bottom_blobs, top_blobs, opt); -#endif } } // namespace ncnn diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h index 51ffb4139..cab7d7e32 100644 --- a/src/layer/x86/lstm_x86.h +++ b/src/layer/x86/lstm_x86.h @@ -31,6 +31,9 @@ public: virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; public: + Mat weight_xc_data_packed; + Mat bias_c_data_packed; + Mat weight_hc_data_packed; }; } // namespace ncnn diff --git a/tests/test_lstm.cpp b/tests/test_lstm.cpp index f002a1aec..fb76ad0fb 100644 --- a/tests/test_lstm.cpp +++ b/tests/test_lstm.cpp @@ -15,50 +15,64 @@ #include "layer/lstm.h" #include "testutil.h" -static int test_lstm(const ncnn::Mat& a, int outch, int direction) +static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 4 * num_directions); + pd.set(1, hidden_size * input_size * 4 * num_directions); pd.set(2, direction); + pd.set(3, hidden_size); - std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 4 * num_directions); - weights[1] = RandomMat(outch * 4 * num_directions); - weights[2] = RandomMat(outch * outch * 4 * num_directions); + std::vector weights(hidden_size == 0 ? 3 : 4); + weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); + if (hidden_size) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + } int ret = test_layer("LSTM", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) +int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 4 * num_directions); + pd.set(1, hidden_size * input_size * 4 * num_directions); pd.set(2, direction); + pd.set(3, hidden_size); - std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 4 * num_directions); - weights[1] = RandomMat(outch * 4 * num_directions); - weights[2] = RandomMat(outch * outch * 4 * num_directions); + std::vector weights(hidden_size == 0 ? 3 : 4); + weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); + if (hidden_size) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + } // initial hidden state ncnn::Mat hidden = RandomMat(outch, num_directions); // initial cell state - ncnn::Mat cell = RandomMat(outch, num_directions); + ncnn::Mat cell = RandomMat(hidden_size, num_directions); std::vector as(3); as[0] = a; @@ -68,32 +82,39 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 4 * num_directions); + pd.set(1, hidden_size * input_size * 4 * num_directions); pd.set(2, direction); + pd.set(3, hidden_size); - std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 4 * num_directions); - weights[1] = RandomMat(outch * 4 * num_directions); - weights[2] = RandomMat(outch * outch * 4 * num_directions); + std::vector weights(hidden_size == 0 ? 3 : 4); + weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); + if (hidden_size) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + } // initial hidden state ncnn::Mat hidden = RandomMat(outch, num_directions); // initial cell state - ncnn::Mat cell = RandomMat(outch, num_directions); + ncnn::Mat cell = RandomMat(hidden_size, num_directions); std::vector as(3); as[0] = a; @@ -103,26 +124,33 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi int ret = test_layer("LSTM", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 4 * num_directions); + pd.set(1, hidden_size * input_size * 4 * num_directions); pd.set(2, direction); + pd.set(3, hidden_size); - std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 4 * num_directions); - weights[1] = RandomMat(outch * 4 * num_directions); - weights[2] = RandomMat(outch * outch * 4 * num_directions); + std::vector weights(hidden_size == 0 ? 3 : 4); + weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); + if (hidden_size) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + } std::vector as(1); as[0] = a; @@ -130,7 +158,7 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; @@ -147,7 +175,7 @@ static int test_lstm_0() || test_lstm(RandomMat(5, 16), 16, 2) || test_lstm(RandomMat(3, 16), 8, 2) || test_lstm(RandomMat(8, 16), 16, 2) - || test_lstm(RandomMat(2, 5), 17, 2); + || test_lstm(RandomMat(2, 5), 17, 2, 15); } static int test_lstm_1() @@ -160,7 +188,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2) || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2) || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2) + || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2, 33) || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1) || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1) || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1) @@ -168,7 +196,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 1) || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 1) || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1) + || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1, 33) || test_lstm_layer_with_hidden(RandomMat(4, 2), 1, 0) || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 0) || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 0) @@ -176,7 +204,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0) || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0) || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0) + || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0, 15) || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2) || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2) @@ -185,7 +213,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2) || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2) || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2, 33) || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1) || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1) || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1) @@ -193,7 +221,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1) || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1) || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1, 33) || test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0) || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0) || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0) @@ -201,7 +229,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0) || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0) || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0, 15) || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2) || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2) @@ -210,7 +238,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2) || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2) || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2, 33) || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1) || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1) || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1) @@ -218,7 +246,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1) || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1) || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1, 33) || test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0) || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0) || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0) @@ -226,7 +254,7 @@ static int test_lstm_1() || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0) || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0) || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0); + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0, 15); } static int test_lstm_2() @@ -240,7 +268,7 @@ static int test_lstm_2() || test_lstm(RandomMat(5, 16), 16, 0) || test_lstm(RandomMat(3, 16), 8, 0) || test_lstm(RandomMat(8, 16), 16, 0) - || test_lstm(RandomMat(2, 5), 17, 0); + || test_lstm(RandomMat(2, 5), 17, 0, 15); } static int test_lstm_3() { @@ -253,7 +281,7 @@ static int test_lstm_3() || test_lstm(RandomMat(5, 16), 16, 1) || test_lstm(RandomMat(3, 16), 8, 1) || test_lstm(RandomMat(8, 16), 16, 1) - || test_lstm(RandomMat(2, 5), 17, 1); + || test_lstm(RandomMat(2, 5), 17, 1, 15); } int main() diff --git a/tools/pnnx/src/pass_level1/nn_LSTM.cpp b/tools/pnnx/src/pass_level1/nn_LSTM.cpp index ba82232ce..a2354dfad 100644 --- a/tools/pnnx/src/pass_level1/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level1/nn_LSTM.cpp @@ -33,9 +33,9 @@ public: void write(Operator* op, const std::shared_ptr& graph, const torch::jit::Module& mod) const { - // mod.dump(true, true, true); - - // graph->dump(); + // mod.dump(true, true, true); + // + // graph->dump(); const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); @@ -49,12 +49,13 @@ public: op->params["pnnx_rnn_output_swapped"] = 1; } - // for (auto aa : lstm->schema().arguments()) - // { - // fprintf(stderr, "arg %s\n", aa.name().c_str()); - // } + // for (auto aa : lstm->schema().arguments()) + // { + // fprintf(stderr, "arg %s\n", aa.name().c_str()); + // } const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); + const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor(); op->params["input_size"] = weight_ih_l0.size(1); op->params["hidden_size"] = weight_ih_l0.size(0) / 4; @@ -62,10 +63,12 @@ public: op->params["bias"] = lstm->namedInput("has_biases"); op->params["batch_first"] = lstm->namedInput("batch_first"); op->params["bidirectional"] = lstm->namedInput("bidirectional"); + op->params["proj_size"] = weight_ih_l0.size(0) / 4 == weight_hh_l0.size(1) ? 0 : weight_hh_l0.size(1); const int num_layers = op->params["num_layers"].i; const bool bias = op->params["bias"].b; const bool bidirectional = op->params["bidirectional"].b; + const int proj_size = op->params["proj_size"].i; for (int k = 0; k < num_layers; k++) { @@ -84,6 +87,13 @@ public: op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); } + if (proj_size > 0) + { + std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k); + + op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor(); + } + if (bidirectional) { std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; @@ -100,6 +110,13 @@ public: op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); } + + if (proj_size > 0) + { + std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse"; + + op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor(); + } } } } diff --git a/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp index c832353be..2fda02423 100644 --- a/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp +++ b/tools/pnnx/src/pass_level5/unroll_rnn_op.cpp @@ -42,6 +42,7 @@ void unroll_rnn_op(Graph& graph) bool has_output_hidden = op->outputs.size() >= 2; bool has_output_cell = op->outputs.size() == 3; const int hidden_size = op->params["hidden_size"].i; + const int proj_size = (op->type == "nn.LSTM") ? op->params["proj_size"].i : 0; bool has_bias = op->params["bias"].b; bool is_bidirectional = op->params["bidirectional"].b; @@ -116,7 +117,14 @@ void unroll_rnn_op(Graph& graph) } else { - op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size; + if (proj_size) + { + op1->params["input_size"] = is_bidirectional ? proj_size * 2 : proj_size; + } + else + { + op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size; + } op1->inputs.push_back(unrolled_ops[j - 1]->outputs[0]); op1->inputs[0]->consumers.push_back(op1); @@ -171,6 +179,11 @@ void unroll_rnn_op(Graph& graph) op1->attrs["bias_ih_l0"] = op->attrs["bias_ih_l" + std::to_string(j)]; } + if (proj_size) + { + op1->attrs["weight_hr_l0"] = op->attrs["weight_hr_l" + std::to_string(j)]; + } + if (is_bidirectional) { op1->attrs["weight_hh_l0_reverse"] = op->attrs["weight_hh_l" + std::to_string(j) + "_reverse"]; @@ -181,6 +194,11 @@ void unroll_rnn_op(Graph& graph) op1->attrs["bias_hh_l0_reverse"] = op->attrs["bias_hh_l" + std::to_string(j) + "_reverse"]; op1->attrs["bias_ih_l0_reverse"] = op->attrs["bias_ih_l" + std::to_string(j) + "_reverse"]; } + + if (proj_size) + { + op1->attrs["weight_hr_l0_reverse"] = op->attrs["weight_hr_l" + std::to_string(j) + "_reverse"]; + } } unrolled_ops[j] = op1; diff --git a/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp b/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp index ba62b5271..1a1511680 100644 --- a/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_ncnn/nn_LSTM.cpp @@ -27,7 +27,7 @@ public: return R"PNNXIR(7767517 3 4 pnnx.Input input 0 1 input -nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse pnnx.Output output 3 0 out out_hidden out_cell )PNNXIR"; } @@ -46,14 +46,19 @@ pnnx.Output output 3 0 out out_hidden out_cell { const bool bidirectional = captured_params.at("bidirectional").b; const int num_directions = bidirectional ? 2 : 1; - const int num_output = captured_params.at("hidden_size").i; + const int hidden_size = captured_params.at("hidden_size").i; const int input_size = captured_params.at("input_size").i; - int weight_data_size = num_directions * num_output * input_size * 4; + int proj_size = captured_params.at("proj_size").i; + if (proj_size == 0) + proj_size = hidden_size; - op->params["0"] = num_output; + int weight_data_size = num_directions * hidden_size * input_size * 4; + + op->params["0"] = proj_size; op->params["1"] = weight_data_size; op->params["2"] = bidirectional ? 2 : 0; + op->params["3"] = hidden_size; op->attrs["0"] = Attribute(); op->attrs["0"].data = {0, 0, 0, 0}; @@ -62,7 +67,7 @@ pnnx.Output output 3 0 out out_hidden out_cell { std::vector new_weight_ih; { - const int weight_data_size_g = num_output * input_size; + const int weight_data_size_g = hidden_size * input_size; const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0").data.data(); const float* iptr = weight_ih; @@ -70,7 +75,7 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* gptr = weight_ih + weight_data_size_g * 2; const float* optr = weight_ih + weight_data_size_g * 3; - new_weight_ih.resize(4 * num_output * input_size); + new_weight_ih.resize(4 * hidden_size * input_size); float* weight = (float*)new_weight_ih.data(); float* w_iptr = weight; float* w_fptr = weight + weight_data_size_g; @@ -86,7 +91,7 @@ pnnx.Output output 3 0 out out_hidden out_cell { std::vector new_weight_ih_reverse; { - const int weight_data_size_g = num_output * input_size; + const int weight_data_size_g = hidden_size * input_size; const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0_reverse").data.data(); const float* iptr = weight_ih; @@ -94,7 +99,7 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* gptr = weight_ih + weight_data_size_g * 2; const float* optr = weight_ih + weight_data_size_g * 3; - new_weight_ih_reverse.resize(4 * num_output * input_size); + new_weight_ih_reverse.resize(4 * hidden_size * input_size); float* weight = (float*)new_weight_ih_reverse.data(); float* w_iptr = weight; float* w_fptr = weight + weight_data_size_g; @@ -105,11 +110,11 @@ pnnx.Output output 3 0 out out_hidden out_cell memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); } - op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih) + Attribute({4, num_output, input_size}, new_weight_ih_reverse); + op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih) + Attribute({4, hidden_size, input_size}, new_weight_ih_reverse); } else { - op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih); + op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih); } } @@ -124,33 +129,33 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); const float* bias_ih_iptr = bias_ih; - const float* bias_ih_fptr = bias_ih + num_output; - const float* bias_ih_gptr = bias_ih + num_output * 2; - const float* bias_ih_optr = bias_ih + num_output * 3; + const float* bias_ih_fptr = bias_ih + hidden_size; + const float* bias_ih_gptr = bias_ih + hidden_size * 2; + const float* bias_ih_optr = bias_ih + hidden_size * 3; const float* bias_hh_iptr = bias_hh; - const float* bias_hh_fptr = bias_hh + num_output; - const float* bias_hh_gptr = bias_hh + num_output * 2; - const float* bias_hh_optr = bias_hh + num_output * 3; + const float* bias_hh_fptr = bias_hh + hidden_size; + const float* bias_hh_gptr = bias_hh + hidden_size * 2; + const float* bias_hh_optr = bias_hh + hidden_size * 3; - new_bias.resize(4 * num_output); + new_bias.resize(4 * hidden_size); float* bias = (float*)new_bias.data(); float* b_iptr = bias; - float* b_fptr = bias + num_output; - float* b_optr = bias + num_output * 2; - float* b_gptr = bias + num_output * 3; - for (int i = 0; i < num_output; i++) + float* b_fptr = bias + hidden_size; + float* b_optr = bias + hidden_size * 2; + float* b_gptr = bias + hidden_size * 3; + for (int i = 0; i < hidden_size; i++) { b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; } @@ -163,63 +168,63 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); const float* bias_ih_iptr = bias_ih; - const float* bias_ih_fptr = bias_ih + num_output; - const float* bias_ih_gptr = bias_ih + num_output * 2; - const float* bias_ih_optr = bias_ih + num_output * 3; + const float* bias_ih_fptr = bias_ih + hidden_size; + const float* bias_ih_gptr = bias_ih + hidden_size * 2; + const float* bias_ih_optr = bias_ih + hidden_size * 3; const float* bias_hh_iptr = bias_hh; - const float* bias_hh_fptr = bias_hh + num_output; - const float* bias_hh_gptr = bias_hh + num_output * 2; - const float* bias_hh_optr = bias_hh + num_output * 3; + const float* bias_hh_fptr = bias_hh + hidden_size; + const float* bias_hh_gptr = bias_hh + hidden_size * 2; + const float* bias_hh_optr = bias_hh + hidden_size * 3; - new_bias_reverse.resize(4 * num_output); + new_bias_reverse.resize(4 * hidden_size); float* bias = (float*)new_bias_reverse.data(); float* b_iptr = bias; - float* b_fptr = bias + num_output; - float* b_optr = bias + num_output * 2; - float* b_gptr = bias + num_output * 3; - for (int i = 0; i < num_output; i++) + float* b_fptr = bias + hidden_size; + float* b_optr = bias + hidden_size * 2; + float* b_gptr = bias + hidden_size * 3; + for (int i = 0; i < hidden_size; i++) { b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i]; } - for (int i = 0; i < num_output; i++) + for (int i = 0; i < hidden_size; i++) { b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i]; } } - op->attrs["3"] = Attribute({4, num_output}, new_bias) + Attribute({4, num_output}, new_bias_reverse); + op->attrs["3"] = Attribute({4, hidden_size}, new_bias) + Attribute({4, hidden_size}, new_bias_reverse); } else { - op->attrs["3"] = Attribute({4, num_output}, new_bias); + op->attrs["3"] = Attribute({4, hidden_size}, new_bias); } } else { - std::vector bias(4 * num_output, 0.f); + std::vector bias(4 * hidden_size, 0.f); if (bidirectional) - op->attrs["3"] = Attribute({4, num_output}, bias) + Attribute({4, num_output}, bias); + op->attrs["3"] = Attribute({4, hidden_size}, bias) + Attribute({4, hidden_size}, bias); else - op->attrs["3"] = Attribute({4, num_output}, bias); + op->attrs["3"] = Attribute({4, hidden_size}, bias); } op->attrs["4"] = Attribute(); op->attrs["4"].data = {0, 0, 0, 0}; - // reorder IFGO-hidden-hidden to IFOG-hidden-hidden + // reorder IFGO-hidden-proj to IFOG-hidden-proj { std::vector new_weight_hh; { - const int weight_data_size_g = num_output * num_output; + const int weight_data_size_g = hidden_size * proj_size; const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0").data.data(); const float* iptr = weight_hh; @@ -227,7 +232,7 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* gptr = weight_hh + weight_data_size_g * 2; const float* optr = weight_hh + weight_data_size_g * 3; - new_weight_hh.resize(4 * num_output * num_output); + new_weight_hh.resize(4 * hidden_size * proj_size); float* weight = (float*)new_weight_hh.data(); float* w_iptr = weight; float* w_fptr = weight + weight_data_size_g; @@ -243,7 +248,7 @@ pnnx.Output output 3 0 out out_hidden out_cell { std::vector new_weight_hh_reverse; { - const int weight_data_size_g = num_output * num_output; + const int weight_data_size_g = hidden_size * proj_size; const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0_reverse").data.data(); const float* iptr = weight_hh; @@ -251,7 +256,7 @@ pnnx.Output output 3 0 out out_hidden out_cell const float* gptr = weight_hh + weight_data_size_g * 2; const float* optr = weight_hh + weight_data_size_g * 3; - new_weight_hh_reverse.resize(4 * num_output * num_output); + new_weight_hh_reverse.resize(4 * hidden_size * proj_size); float* weight = (float*)new_weight_hh_reverse.data(); float* w_iptr = weight; float* w_fptr = weight + weight_data_size_g; @@ -262,11 +267,26 @@ pnnx.Output output 3 0 out out_hidden out_cell memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); } - op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh) + Attribute({4, num_output, num_output}, new_weight_hh_reverse); + op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh) + Attribute({4, hidden_size, proj_size}, new_weight_hh_reverse); + } + else + { + op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh); + } + } + + if (proj_size != hidden_size) + { + op->attrs["6"] = Attribute(); + op->attrs["6"].data = {0, 0, 0, 0}; + + if (bidirectional) + { + op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0") + captured_attrs.at("op_0.weight_hr_l0_reverse"); } else { - op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh); + op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0"); } } } @@ -284,7 +304,7 @@ public: pnnx.Input input 0 1 input pnnx.Input in_hidden 0 1 in_hidden pnnx.Input in_hidden 0 1 in_cell -nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse pnnx.Output output 3 0 out out_hidden out_cell )PNNXIR"; } @@ -300,7 +320,7 @@ public: return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse pnnx.Output output 1 0 out )PNNXIR"; } @@ -318,7 +338,7 @@ public: pnnx.Input input 0 1 input pnnx.Input in_hidden 0 1 in_hidden pnnx.Input in_hidden 0 1 in_cell -nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse +nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse pnnx.Output output 1 0 out )PNNXIR"; } diff --git a/tools/pnnx/tests/ncnn/test_nn_LSTM.py b/tools/pnnx/tests/ncnn/test_nn_LSTM.py index 575d44aac..a51f5e940 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LSTM.py +++ b/tools/pnnx/tests/ncnn/test_nn_LSTM.py @@ -22,15 +22,15 @@ class Model(nn.Module): self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) - self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) - self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) - self.lstm_0_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) + self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) + self.lstm_0_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) - self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) - self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) - self.lstm_1_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) + self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) + self.lstm_1_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) def forward(self, x, y): x = x.permute(1, 0, 2) @@ -38,14 +38,14 @@ class Model(nn.Module): x0, _ = self.lstm_0_0(x) x1, _ = self.lstm_0_1(x0) x2, (h0, c0) = self.lstm_0_2(x1) - x3, (h1, c1) = self.lstm_0_3(x1, (h0, c0)) - x4, _ = self.lstm_0_4(x1, (h1, c1)) + x3, (h1, c1) = self.lstm_0_3(x2, (h0, c0)) + x4, _ = self.lstm_0_4(x3, (h1, c1)) y0, _ = self.lstm_1_0(y) y1, _ = self.lstm_1_1(y0) y2, (h2, c2) = self.lstm_1_2(y1) - y3, (h3, c3) = self.lstm_1_3(y1, (h2, c2)) - y4, _ = self.lstm_1_4(y1, (h3, c3)) + y3, (h3, c3) = self.lstm_1_3(y2, (h2, c2)) + y4, _ = self.lstm_1_4(y3, (h3, c3)) x2 = x2.permute(1, 0, 2) x3 = x3.permute(1, 0, 2) diff --git a/tools/pnnx/tests/test_nn_LSTM.py b/tools/pnnx/tests/test_nn_LSTM.py index 36274c8fc..33c542190 100644 --- a/tools/pnnx/tests/test_nn_LSTM.py +++ b/tools/pnnx/tests/test_nn_LSTM.py @@ -22,24 +22,24 @@ class Model(nn.Module): self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16) self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False) - self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) - self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True) + self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) + self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10) self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True) self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True) - self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) - self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True) + self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) + self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10) def forward(self, x, y): x0, (h0, c0) = self.lstm_0_0(x) x1, (h1, c1) = self.lstm_0_1(x0) x2, (h2, c2) = self.lstm_0_2(x1) - x3, (h3, c3) = self.lstm_0_3(x1, (h2, c2)) + x3, (h3, c3) = self.lstm_0_3(x2, (h2, c2)) y0, (h4, c4) = self.lstm_1_0(y) y1, (h5, c5) = self.lstm_1_1(y0) y2, (h6, c6) = self.lstm_1_2(y1) - y3, (h7, c7) = self.lstm_1_3(y1, (h6, c6)) + y3, (h7, c7) = self.lstm_1_3(y2, (h6, c6)) return x2, x3, h0, h1, h2, h3, c0, c1, c2, c3, y2, y3, h4, h5, h6, h7, c4, c5, c6, c7 def test():