Browse Source

implement lstm proj_size (#4263)

tags/20221128
nihui GitHub 3 years ago
parent
commit
77eda4c19f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1133 additions and 606 deletions
  1. +1
    -1
      .github/workflows/test-coverage.yml
  2. +6
    -4
      docs/developer-guide/operators.md
  3. +188
    -68
      src/layer/arm/lstm_arm.cpp
  4. +199
    -86
      src/layer/arm/lstm_arm_asimdhp.cpp
  5. +70
    -27
      src/layer/lstm.cpp
  6. +2
    -0
      src/layer/lstm.h
  7. +462
    -301
      src/layer/x86/lstm_x86.cpp
  8. +3
    -0
      src/layer/x86/lstm_x86.h
  9. +70
    -42
      tests/test_lstm.cpp
  10. +24
    -7
      tools/pnnx/src/pass_level1/nn_LSTM.cpp
  11. +19
    -1
      tools/pnnx/src/pass_level5/unroll_rnn_op.cpp
  12. +73
    -53
      tools/pnnx/src/pass_ncnn/nn_LSTM.cpp
  13. +10
    -10
      tools/pnnx/tests/ncnn/test_nn_LSTM.py
  14. +6
    -6
      tools/pnnx/tests/test_nn_LSTM.py

+ 1
- 1
.github/workflows/test-coverage.yml View File

@@ -227,7 +227,7 @@ jobs:
uses: actions/cache@v3
with:
path: lavapipe-install
key: lavapipe-linux-install-20211127-2
key: lavapipe-linux-install-20211127-3
- name: checkout-lavapipe
if: steps.cache-lavapipe.outputs.cache-hit != 'true'
uses: actions/checkout@v3


+ 6
- 4
docs/developer-guide/operators.md View File

@@ -1026,15 +1026,17 @@ y0, hidden y1, cell y2 = lstm(x0, hidden x1, cell x2)

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | num_output | int | 0 | hidden size of output |
| 0 | num_output | int | 0 | size of each output vector (equals hidden_size unless the hidden state is projected through weight_hr_data) |
| 1 | weight_data_size| int | 0 | total size of IFOG weight matrix |
| 2 | direction | int | 0 | 0=forward, 1=reverse, 2=bidirectional |
| 3 | hidden_size | int | num_output| size of the cell state and of each IFOG gate; defaults to num_output (no projection) |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| weight_xc_data| float/fp16/int8 | [input_size, num_output * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [num_output, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, num_output * 4, num_directions] |
| weight_xc_data| float/fp16/int8 | [input_size, hidden_size * 4, num_directions] |
| bias_c_data | float/fp16/int8 | [hidden_size, 4, num_directions] |
| weight_hc_data| float/fp16/int8 | [num_output, hidden_size * 4, num_directions] |
| weight_hr_data| float/fp16/int8 | [hidden_size, num_output, num_directions] |

Direction flag:
- 0 = forward only


+ 188
- 68
src/layer/arm/lstm_arm.cpp View File

@@ -58,11 +58,11 @@ int LSTM_arm::create_pipeline(const Option& opt)

// pack IFOG
int num_directions = direction == 2 ? 2 : 1;
int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

weight_xc_data_packed.create(size, num_output, num_directions, 16u, 4);
bias_c_data_packed.create(num_output, 1, num_directions, 16u, 4);
weight_hc_data_packed.create(num_output, num_output, num_directions, 16u, 4);
weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4);
bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4);
weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4);

#pragma omp parallel for num_threads(opt.num_threads)
for (int dr = 0; dr < num_directions; dr++)
@@ -82,7 +82,7 @@ int LSTM_arm::create_pipeline(const Option& opt)

float* bias_c_IFOG = bias_c_data_packed_dr.row(0);

for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
bias_c_IFOG[0] = bias_c_I[q];
bias_c_IFOG[1] = bias_c_F[q];
@@ -91,15 +91,15 @@ int LSTM_arm::create_pipeline(const Option& opt)

bias_c_IFOG += 4;

const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q);
float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q);
@@ -126,21 +126,37 @@ int LSTM_arm::create_pipeline(const Option& opt)
}
}

if (opt.lightmode)
{
weight_xc_data.release();
bias_c_data.release();
weight_hc_data.release();
}

return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -155,7 +171,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

const float* x = bottom_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* bias_c_IFOG = (const float*)bias_c + q * 4;

@@ -291,14 +307,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

int remain_num_output_start = 0;
int remain_hidden_size_start = 0;
#if __ARM_NEON
int nn_num_output = num_output >> 2;
remain_num_output_start = nn_num_output << 2;
int nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

@@ -315,12 +332,20 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);
vst1q_f32(hidden_ptr + q, _H);
vst1q_f32(output_data + q, _H);

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _H);
vst1q_f32(output_data + q, _H);
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _H);
}
}
#endif // __ARM_NEON
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -338,8 +363,43 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float H = O * tanh(cell2);

cell_ptr[q] = cell2;
hidden_ptr[q] = H;
output_data[q] = H;
if (num_output == hidden_size)
{
hidden_ptr[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;
output_data[q] = H;
}
}
}

@@ -375,7 +435,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -387,7 +447,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -402,14 +462,14 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -466,7 +526,7 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -480,7 +540,7 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -497,13 +557,13 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;

@@ -529,18 +589,27 @@ int LSTM_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
}

#if NCNN_BF16
static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -555,7 +624,7 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const

const unsigned short* x = bottom_blob.row<const unsigned short>(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const unsigned short* bias_c_IFOG = (const unsigned short*)bias_c + q * 4;

@@ -693,14 +762,15 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

int remain_num_output_start = 0;
int remain_hidden_size_start = 0;
#if __ARM_NEON
int nn_num_output = num_output >> 2;
remain_num_output_start = nn_num_output << 2;
int nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

@@ -717,12 +787,20 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);
vst1q_f32(hidden_ptr + q, _H);
vst1_u16(output_data + q, bfloat2float(_H));

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _H);
vst1_u16(output_data + q, bfloat2float(_H));
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _H);
}
}
#endif // __ARM_NEON
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -740,8 +818,43 @@ static int lstm_bf16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float H = O * tanh(cell2);

cell_ptr[q] = cell2;
hidden_ptr[q] = H;
output_data[q] = float32_to_bfloat16(H);
if (num_output == hidden_size)
{
hidden_ptr[q] = H;
output_data[q] = float32_to_bfloat16(H);
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;
output_data[q] = float32_to_bfloat16(H);
}
}
}

@@ -752,11 +865,11 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt)
{
// pack IFOG
int num_directions = direction == 2 ? 2 : 1;
int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);
weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4);
bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4);

#pragma omp parallel for num_threads(opt.num_threads)
for (int dr = 0; dr < num_directions; dr++)
@@ -776,7 +889,7 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt)

unsigned short* bias_c_IFOG = bias_c_data_packed_dr.row<unsigned short>(0);

for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
bias_c_IFOG[0] = float32_to_bfloat16(bias_c_I[q]);
bias_c_IFOG[1] = float32_to_bfloat16(bias_c_F[q]);
@@ -785,15 +898,15 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt)

bias_c_IFOG += 4;

const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

unsigned short* weight_xc_IFOG = weight_xc_data_packed_dr.row<unsigned short>(q);
unsigned short* weight_hc_IFOG = weight_hc_data_packed_dr.row<unsigned short>(q);
@@ -820,6 +933,13 @@ int LSTM_arm::create_pipeline_bf16s(const Option& opt)
}
}

if (opt.lightmode)
{
weight_xc_data.release();
bias_c_data.release();
weight_hc_data.release();
}

return 0;
}

@@ -835,7 +955,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option&
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -847,7 +967,7 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option&
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -862,14 +982,14 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option&
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.f);
cell.fill(0.f);

int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -911,7 +1031,7 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -925,7 +1045,7 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_bf16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -942,13 +1062,13 @@ int LSTM_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt);
int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt);
int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;



+ 199
- 86
src/layer/arm/lstm_arm_asimdhp.cpp View File

@@ -25,18 +25,27 @@
namespace ncnn {

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -51,7 +60,7 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const

const __fp16* x = bottom_blob.row<const __fp16>(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

@@ -141,11 +150,12 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

int nn_num_output = num_output >> 2;
int remain_num_output_start = nn_num_output << 2;
int nn_hidden_size = hidden_size >> 2;
int remain_hidden_size_start = nn_hidden_size << 2;
#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

@@ -162,11 +172,19 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);
vst1q_f32(hidden_ptr + q, _H);
vst1_f16(output_data + q, vcvt_f16_f32(_H));

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _H);
vst1_f16(output_data + q, vcvt_f16_f32(_H));
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _H);
}
}
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -184,26 +202,70 @@ static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float H = O * tanh(cell2);

cell_ptr[q] = cell2;
hidden_ptr[q] = H;
output_data[q] = (__fp16)(H);
if (num_output == hidden_size)
{
hidden_ptr[q] = H;
output_data[q] = (__fp16)H;
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;
output_data[q] = (__fp16)H;
}
}
}

return 0;
}

static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 2u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 2u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -216,10 +278,10 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const

int ti = reverse ? T - 1 - t : t;

int nn_num_output = num_output >> 1;
int remain_num_output_start = nn_num_output << 1;
int nn_hidden_size = hidden_size >> 1;
int remain_hidden_size_start = nn_hidden_size << 1;
#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 2;

@@ -319,7 +381,7 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const
vst1q_f16(gates_data, _IFOG);
}
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4;

@@ -428,11 +490,12 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

nn_num_output = num_output >> 2;
remain_num_output_start = nn_num_output << 2;
nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;
#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

@@ -449,11 +512,19 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float32x4_t _H = vmulq_f32(_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);
vst1q_f32(hidden_ptr + q, _H);
vst1_f16(output_data + q, vcvt_f16_f32(_H));

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _H);
vst1_f16(output_data + q, vcvt_f16_f32(_H));
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _H);
}
}
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const __fp16* gates_data = gates.row<const __fp16>(q);

@@ -471,8 +542,43 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const
float H = O * tanh(cell2);

cell_ptr[q] = cell2;
hidden_ptr[q] = H;
output_data[q] = (__fp16)H;
if (num_output == hidden_size)
{
hidden_ptr[q] = H;
output_data[q] = (__fp16)H;
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;
output_data[q] = (__fp16)H;
}
}
}

@@ -483,19 +589,19 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)
{
// pack IFOG
int num_directions = direction == 2 ? 2 : 1;
int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

if (opt.use_fp16_arithmetic)
{
weight_xc_data_packed.create(size, num_output / 2 + num_output % 2, num_directions, 16u, 8);
bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, num_output / 2 + num_output % 2, num_directions, 16u, 8);
weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8);
bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8);
}
else
{
weight_xc_data_packed.create(size, num_output, num_directions, 8u, 4);
bias_c_data_packed.create(num_output, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, num_output, num_directions, 8u, 4);
weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4);
bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4);
weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4);
}

#pragma omp parallel for num_threads(opt.num_threads)
@@ -519,7 +625,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)
if (opt.use_fp16_arithmetic)
{
int q = 0;
for (; q + 1 < num_output; q += 2)
for (; q + 1 < hidden_size; q += 2)
{
bias_c_IFOG[0] = (__fp16)bias_c_I[q];
bias_c_IFOG[1] = (__fp16)bias_c_F[q];
@@ -532,23 +638,23 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)

bias_c_IFOG += 8;

const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + q + 1);
const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + q + 1);
const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + q + 1);
const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + q + 1);
const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + q + 1);
const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + q + 1);
const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + q + 1);
const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + q + 1);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);
const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1);
const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1);
const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1);
const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);
const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1);
const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1);
const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1);
const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1);

__fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2);
__fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2);
@@ -581,7 +687,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)
weight_hc_IFOG += 8;
}
}
for (; q < num_output; q++)
for (; q < hidden_size; q++)
{
bias_c_IFOG[0] = (__fp16)bias_c_I[q];
bias_c_IFOG[1] = (__fp16)bias_c_F[q];
@@ -590,15 +696,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)

bias_c_IFOG += 4;

const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

__fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2);
__fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2);
@@ -626,7 +732,7 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)
}
else
{
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
bias_c_IFOG[0] = (__fp16)bias_c_I[q];
bias_c_IFOG[1] = (__fp16)bias_c_F[q];
@@ -635,15 +741,15 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)

bias_c_IFOG += 4;

const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

__fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q);
__fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q);
@@ -671,6 +777,13 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt)
}
}

if (opt.lightmode)
{
weight_xc_data.release();
bias_c_data.release();
weight_hc_data.release();
}

return 0;
}

@@ -686,7 +799,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option&
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -698,7 +811,7 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option&
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -713,14 +826,14 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option&
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.f);
cell.fill(0.f);

int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -762,7 +875,7 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -776,7 +889,7 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -793,13 +906,13 @@ int LSTM_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt);
int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt);
int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;

@@ -836,7 +949,7 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -848,7 +961,7 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -863,14 +976,14 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.f);
cell.fill(0.f);

int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, cell, opt);
int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -912,7 +1025,7 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -926,7 +1039,7 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt);
int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -943,13 +1056,13 @@ int LSTM_arm::forward_fp16sa(const std::vector<Mat>& bottom_blobs, std::vector<M

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt);
int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt);
int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;



+ 70
- 27
src/layer/lstm.cpp View File

@@ -29,6 +29,7 @@ int LSTM::load_param(const ParamDict& pd)
num_output = pd.get(0, 0);
weight_data_size = pd.get(1, 0);
direction = pd.get(2, 0);
hidden_size = pd.get(3, num_output);
return 0;
}

@@ -36,36 +37,52 @@ int LSTM::load_model(const ModelBin& mb)
{
int num_directions = direction == 2 ? 2 : 1;

int size = weight_data_size / num_directions / num_output / 4;
int size = weight_data_size / num_directions / hidden_size / 4;

// raw weight data
weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
weight_xc_data = mb.load(size, hidden_size * 4, num_directions, 0);
if (weight_xc_data.empty())
return -100;

bias_c_data = mb.load(num_output, 4, num_directions, 0);
bias_c_data = mb.load(hidden_size, 4, num_directions, 0);
if (bias_c_data.empty())
return -100;

weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
weight_hc_data = mb.load(num_output, hidden_size * 4, num_directions, 0);
if (weight_hc_data.empty())
return -100;

if (num_output != hidden_size)
{
weight_hr_data = mb.load(hidden_size, num_output, num_directions, 0);
if (weight_hr_data.empty())
return -100;
}

return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(4, num_output, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -80,7 +97,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

const float* x = bottom_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
@@ -90,15 +107,15 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float* gates_data = gates.row(q);

// gate I F O G
const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

float I = bias_c_I[q];
float F = bias_c_F[q];
@@ -140,7 +157,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
for (int q = 0; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

@@ -157,8 +174,34 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float cell2 = F * cell_state[q] + I * G;
float H = O * tanh(cell2);
cell_state[q] = cell2;
hidden_state[q] = H;
output_data[q] = H;

if (num_output == hidden_size)
{
hidden_state[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_state[q] = H;
}
}

if (num_output != hidden_size)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
{
const float* hr = weight_hr.row(q);

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_state[i] * hr[i];
}

hidden_state[q] = H;
output_data[q] = H;
}
}
}

@@ -177,7 +220,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
return -100;
hidden.fill(0.f);

Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -189,7 +232,7 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -204,14 +247,14 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -251,7 +294,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -265,7 +308,7 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -282,13 +325,13 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;



+ 2
- 0
src/layer/lstm.h View File

@@ -36,10 +36,12 @@ public:
int num_output;
int weight_data_size;
int direction; // 0=forward 1=reverse 2=bidirectional
int hidden_size;

Mat weight_hc_data;
Mat weight_xc_data;
Mat bias_c_data;
Mat weight_hr_data;
};

} // namespace ncnn


+ 462
- 301
src/layer/x86/lstm_x86.cpp View File

@@ -14,6 +14,13 @@

#include "lstm_x86.h"

#if __SSE2__
#include <emmintrin.h>
#if __AVX__
#include <immintrin.h>
#endif
#endif // __SSE2__

#include "x86_activation.h"
#include "x86_usability.h"

@@ -30,23 +37,183 @@ LSTM_x86::LSTM_x86()

// Repack the raw LSTM weights and biases into gate-interleaved buffers.
//
// For every direction, the four per-gate rows (I, F, O, G) that belong to one
// hidden unit in weight_xc_data, weight_hc_data and bias_c_data are interleaved
// element by element into weight_xc_data_packed, weight_hc_data_packed and
// bias_c_data_packed. Under __AVX__ two adjacent hidden units share one packed
// row (8 floats per step); a trailing odd unit, and every unit in the non-AVX
// build, is written 4 floats per step.
//
// When opt.lightmode is set, the raw weight_xc_data, bias_c_data and
// weight_hc_data are released after packing.
//
// Returns 0 on every path visible in this function.
int LSTM_x86::create_pipeline(const Option& opt)
{
(void)(opt);
// pack IFOG
// direction == 2 means two weight sets (one channel per direction); any other value uses one
int num_directions = direction == 2 ? 2 : 1;
// width of one weight_xc row: total x-side weight count split over directions, hidden units and the 4 gates
int size = weight_data_size / num_directions / hidden_size / 4;

// NOTE(review): the trailing two arguments look like ncnn Mat::create's elemsize (bytes) and
// elempack, i.e. 8 floats per element under AVX and 4 otherwise - confirm against mat.h.
// NOTE(review): none of the create() results are checked before the buffers are written below.
#if __AVX__
// one packed row per pair of hidden units, rounded up so an odd tail unit gets its own row
weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8);
bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4);
weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 32u, 8);
#else
// one packed row per hidden unit
weight_xc_data_packed.create(size, hidden_size, num_directions, 16u, 4);
bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4);
weight_hc_data_packed.create(num_output, hidden_size, num_directions, 16u, 4);
#endif

// each iteration reads and writes only the channel views of its own direction dr
#pragma omp parallel for num_threads(opt.num_threads)
for (int dr = 0; dr < num_directions; dr++)
{
const Mat weight_xc = weight_xc_data.channel(dr);
const Mat bias_c = bias_c_data.channel(dr);
const Mat weight_hc = weight_hc_data.channel(dr);

Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr);
Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr);
Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr);

// raw bias rows are stored in gate order I, F, O, G
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
const float* bias_c_O = bias_c.row(2);
const float* bias_c_G = bias_c.row(3);

// single output row for the packed bias; the pointer advances as units are packed
float* bias_c_IFOG = bias_c_data_packed_dr.row(0);

int q = 0;
#if __AVX__
// pack hidden units two at a time: q and q + 1
for (; q + 1 < hidden_size; q += 2)
{
// packed bias: I,F,O,G of unit q followed by I,F,O,G of unit q + 1
bias_c_IFOG[0] = bias_c_I[q];
bias_c_IFOG[1] = bias_c_F[q];
bias_c_IFOG[2] = bias_c_O[q];
bias_c_IFOG[3] = bias_c_G[q];
bias_c_IFOG[4] = bias_c_I[q + 1];
bias_c_IFOG[5] = bias_c_F[q + 1];
bias_c_IFOG[6] = bias_c_O[q + 1];
bias_c_IFOG[7] = bias_c_G[q + 1];

bias_c_IFOG += 8;

// raw row (hidden_size * g + q) holds gate g (0=I, 1=F, 2=O, 3=G) of hidden unit q
const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);
const float* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1);
const float* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1);
const float* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1);
const float* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1);

const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);
const float* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1);
const float* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1);
const float* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1);
const float* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1);

// packed row q / 2 holds the pair (q, q + 1)
float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2);
float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2);

// x-side weights: for each input column i emit 8 floats (4 gates of q, then 4 gates of q + 1)
for (int i = 0; i < size; i++)
{
weight_xc_IFOG[0] = weight_xc_I[i];
weight_xc_IFOG[1] = weight_xc_F[i];
weight_xc_IFOG[2] = weight_xc_O[i];
weight_xc_IFOG[3] = weight_xc_G[i];
weight_xc_IFOG[4] = weight_xc_I_1[i];
weight_xc_IFOG[5] = weight_xc_F_1[i];
weight_xc_IFOG[6] = weight_xc_O_1[i];
weight_xc_IFOG[7] = weight_xc_G_1[i];

weight_xc_IFOG += 8;
}

// h-side weights: same interleaving, but the column count is num_output rather than size
for (int i = 0; i < num_output; i++)
{
weight_hc_IFOG[0] = weight_hc_I[i];
weight_hc_IFOG[1] = weight_hc_F[i];
weight_hc_IFOG[2] = weight_hc_O[i];
weight_hc_IFOG[3] = weight_hc_G[i];
weight_hc_IFOG[4] = weight_hc_I_1[i];
weight_hc_IFOG[5] = weight_hc_F_1[i];
weight_hc_IFOG[6] = weight_hc_O_1[i];
weight_hc_IFOG[7] = weight_hc_G_1[i];

weight_hc_IFOG += 8;
}
}
#endif // __AVX__
// remaining units one at a time: the odd tail under AVX, or every unit when
// __AVX__ is not defined (q is still 0 in that case)
for (; q < hidden_size; q++)
{
bias_c_IFOG[0] = bias_c_I[q];
bias_c_IFOG[1] = bias_c_F[q];
bias_c_IFOG[2] = bias_c_O[q];
bias_c_IFOG[3] = bias_c_G[q];

bias_c_IFOG += 4;

const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q);
const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q);
const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q);
const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q);

const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q);
const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q);
const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q);
const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q);

#if __AVX__
// the pair loop above leaves q even, so q % 2 is 0 here and this selects
// row (hidden_size - 1) / 2, the last packed row
float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2);
float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2);
#else
float* weight_xc_IFOG = weight_xc_data_packed_dr.row(q);
float* weight_hc_IFOG = weight_hc_data_packed_dr.row(q);
#endif

// single-unit interleave: 4 floats (I, F, O, G) per input column
for (int i = 0; i < size; i++)
{
weight_xc_IFOG[0] = weight_xc_I[i];
weight_xc_IFOG[1] = weight_xc_F[i];
weight_xc_IFOG[2] = weight_xc_O[i];
weight_xc_IFOG[3] = weight_xc_G[i];

weight_xc_IFOG += 4;
}

for (int i = 0; i < num_output; i++)
{
weight_hc_IFOG[0] = weight_hc_I[i];
weight_hc_IFOG[1] = weight_hc_F[i];
weight_hc_IFOG[2] = weight_hc_O[i];
weight_hc_IFOG[3] = weight_hc_G[i];

weight_hc_IFOG += 4;
}
}
}

// drop the raw copies of the three buffers that were packed above;
// weight_hr_data is not packed by this function and is not released here
if (opt.lightmode)
{
weight_xc_data.release();
bias_c_data.release();
weight_hc_data.release();
}

return 0;
}
#ifdef __AVX__
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x num_output
Mat gates(num_output, 4, 4u, opt.workspace_allocator);
// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
@@ -59,267 +226,222 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w

int ti = reverse ? T - 1 - t : t;

int nn_num_output = num_output >> 1;
int remain_num_output_start = nn_num_output << 1;
#if __AVX__
int nn_hidden_size = hidden_size >> 1;
int remain_hidden_size_start = nn_hidden_size << 1;
#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_num_output; qq++)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 2;

const float* x = bottom_blob.row(ti);
const float* hidden_ptr_r = hidden_state;
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
const float* bias_c_O = bias_c.row(2);
const float* bias_c_G = bias_c.row(3);

float* gates_data_I = gates.row(0);
float* gates_data_F = gates.row(1);
float* gates_data_O = gates.row(2);
float* gates_data_G = gates.row(3);
const float* bias_c_IFOG = (const float*)bias_c + q * 4;

// gate I F O G
const float* weight_xc_I_0 = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F_0 = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O_0 = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G_0 = weight_xc.row(num_output * 3 + q);
const float* weight_xc_I_1 = weight_xc.row(num_output * 0 + (q + 1));
const float* weight_xc_F_1 = weight_xc.row(num_output * 1 + (q + 1));
const float* weight_xc_O_1 = weight_xc.row(num_output * 2 + (q + 1));
const float* weight_xc_G_1 = weight_xc.row(num_output * 3 + (q + 1));

const float* weight_hc_I_0 = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F_0 = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O_0 = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G_0 = weight_hc.row(num_output * 3 + q);
const float* weight_hc_I_1 = weight_hc.row(num_output * 0 + (q + 1));
const float* weight_hc_F_1 = weight_hc.row(num_output * 1 + (q + 1));
const float* weight_hc_O_1 = weight_hc.row(num_output * 2 + (q + 1));
const float* weight_hc_G_1 = weight_hc.row(num_output * 3 + (q + 1));

// float I = bias_c_I[q];
// float F = bias_c_F[q];
// float O = bias_c_O[q];
// float G = bias_c_G[q];
__m256 _sumI_0 = _mm256_setzero_ps();
__m256 _sumF_0 = _mm256_setzero_ps();
__m256 _sumO_0 = _mm256_setzero_ps();
__m256 _sumG_0 = _mm256_setzero_ps();
__m256 _sumI_1 = _mm256_setzero_ps();
__m256 _sumF_1 = _mm256_setzero_ps();
__m256 _sumO_1 = _mm256_setzero_ps();
__m256 _sumG_1 = _mm256_setzero_ps();
int nn_num_size = size >> 3;
int remain_size = size & 7;
for (; nn_num_size > 0; nn_num_size--)
const float* weight_xc_IFOG = weight_xc.row(q / 2);
const float* weight_hc_IFOG = weight_hc.row(q / 2);

__m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG);
__m256 _sum1 = _mm256_setzero_ps();
__m256 _sum2 = _mm256_setzero_ps();
__m256 _sum3 = _mm256_setzero_ps();

const float* x = bottom_blob.row(ti);

int i = 0;
for (; i + 3 < size; i += 4)
{
__m256 xi = _mm256_loadu_ps(x);
_sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_0), xi, _sumI_0);
_sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_0), xi, _sumF_0);
_sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_0), xi, _sumO_0);
_sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_0), xi, _sumG_0);
_sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I_1), xi, _sumI_1);
_sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F_1), xi, _sumF_1);
_sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O_1), xi, _sumO_1);
_sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G_1), xi, _sumG_1);
x += 8;
weight_xc_I_0 += 8;
weight_xc_F_0 += 8;
weight_xc_O_0 += 8;
weight_xc_G_0 += 8;
weight_xc_I_1 += 8;
weight_xc_F_1 += 8;
weight_xc_O_1 += 8;
weight_xc_G_1 += 8;
__m256 _xi0 = _mm256_broadcast_ss(x);
__m256 _xi1 = _mm256_broadcast_ss(x + 1);
__m256 _xi2 = _mm256_broadcast_ss(x + 2);
__m256 _xi3 = _mm256_broadcast_ss(x + 3);
__m256 _weight_xc_IFOG0 = _mm256_loadu_ps(weight_xc_IFOG);
__m256 _weight_xc_IFOG1 = _mm256_loadu_ps(weight_xc_IFOG + 8);
__m256 _weight_xc_IFOG2 = _mm256_loadu_ps(weight_xc_IFOG + 16);
__m256 _weight_xc_IFOG3 = _mm256_loadu_ps(weight_xc_IFOG + 24);
_IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG);
_sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1);
_sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2);
_sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3);

x += 4;
weight_xc_IFOG += 32;
}
int nn_num_output = num_output >> 3;
int remain_num_output = num_output & 7;
for (; nn_num_output > 0; nn_num_output--)
for (; i < size; i++)
{
__m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

_sumI_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_0), h_cont, _sumI_0);
_sumF_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_0), h_cont, _sumF_0);
_sumO_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_0), h_cont, _sumO_0);
_sumG_0 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_0), h_cont, _sumG_0);
_sumI_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I_1), h_cont, _sumI_1);
_sumF_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F_1), h_cont, _sumF_1);
_sumO_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O_1), h_cont, _sumO_1);
_sumG_1 = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G_1), h_cont, _sumG_1);
hidden_ptr_r += 8;
weight_hc_I_0 += 8;
weight_hc_F_0 += 8;
weight_hc_O_0 += 8;
weight_hc_G_0 += 8;
weight_hc_I_1 += 8;
weight_hc_F_1 += 8;
weight_hc_O_1 += 8;
weight_hc_G_1 += 8;
__m256 _xi = _mm256_broadcast_ss(x);
__m256 _weight_xc_IFOG = _mm256_loadu_ps(weight_xc_IFOG);
_IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG);

x += 1;
weight_xc_IFOG += 8;
}
float sums[8];
_mm256_storeu_ps(sums, HorizontalSums(_sumI_0, _sumF_0, _sumO_0, _sumG_0, _sumI_1, _sumF_1, _sumO_1, _sumG_1));
sums[0] += bias_c_I[q];
sums[1] += bias_c_F[q];
sums[2] += bias_c_O[q];
sums[3] += bias_c_G[q];
sums[4] += bias_c_I[q + 1];
sums[5] += bias_c_F[q + 1];
sums[6] += bias_c_O[q + 1];
sums[7] += bias_c_G[q + 1];

for (; remain_size > 0; remain_size--)

const float* hidden_ptr = hidden_state;

i = 0;
for (; i + 3 < num_output; i += 4)
{
float xi = *x;
sums[0] += *weight_xc_I_0 * xi;
sums[1] += *weight_xc_F_0 * xi;
sums[2] += *weight_xc_O_0 * xi;
sums[3] += *weight_xc_G_0 * xi;
sums[4] += *weight_xc_I_1 * xi;
sums[5] += *weight_xc_F_1 * xi;
sums[6] += *weight_xc_O_1 * xi;
sums[7] += *weight_xc_G_1 * xi;
x++;
weight_xc_I_0++;
weight_xc_F_0++;
weight_xc_O_0++;
weight_xc_G_0++;
weight_xc_I_1++;
weight_xc_F_1++;
weight_xc_O_1++;
weight_xc_G_1++;
__m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr);
__m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1);
__m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2);
__m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3);
__m256 _weight_hc_IFOG0 = _mm256_loadu_ps(weight_hc_IFOG);
__m256 _weight_hc_IFOG1 = _mm256_loadu_ps(weight_hc_IFOG + 8);
__m256 _weight_hc_IFOG2 = _mm256_loadu_ps(weight_hc_IFOG + 16);
__m256 _weight_hc_IFOG3 = _mm256_loadu_ps(weight_hc_IFOG + 24);
_IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG);
_sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1);
_sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2);
_sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3);

hidden_ptr += 4;
weight_hc_IFOG += 32;
}

for (; remain_num_output > 0; remain_num_output--)
for (; i < num_output; i++)
{
float h_cont = *hidden_ptr_r;
sums[0] += *weight_hc_I_0 * h_cont;
sums[1] += *weight_hc_F_0 * h_cont;
sums[2] += *weight_hc_O_0 * h_cont;
sums[3] += *weight_hc_G_0 * h_cont;
sums[4] += *weight_hc_I_1 * h_cont;
sums[5] += *weight_hc_F_1 * h_cont;
sums[6] += *weight_hc_O_1 * h_cont;
sums[7] += *weight_hc_G_1 * h_cont;
hidden_ptr_r++;
weight_hc_I_0++;
weight_hc_F_0++;
weight_hc_O_0++;
weight_hc_G_0++;
weight_hc_I_1++;
weight_hc_F_1++;
weight_hc_O_1++;
weight_hc_G_1++;
__m256 _h_cont = _mm256_broadcast_ss(hidden_ptr);
__m256 _weight_hc_IFOG = _mm256_loadu_ps(weight_hc_IFOG);
_IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG);

hidden_ptr += 1;
weight_hc_IFOG += 8;
}
gates_data_I[q] = sums[0];
gates_data_F[q] = sums[1];
gates_data_O[q] = sums[2];
gates_data_G[q] = sums[3];
gates_data_I[q + 1] = sums[4];
gates_data_F[q + 1] = sums[5];
gates_data_O[q + 1] = sums[6];
gates_data_G[q + 1] = sums[7];

float* gates_data = gates.row(q);

_IFOG = _mm256_add_ps(_IFOG, _sum1);
_sum2 = _mm256_add_ps(_sum2, _sum3);
_IFOG = _mm256_add_ps(_IFOG, _sum2);

_mm256_storeu_ps(gates_data, _IFOG);
}
#else
int nn_hidden_size = 0;
int remain_hidden_size_start = 0;
#endif // __AVX__

#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* x = bottom_blob.row(ti);
const float* hidden_ptr_r = hidden_state;
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
const float* bias_c_O = bias_c.row(2);
const float* bias_c_G = bias_c.row(3);

float* gates_data_I = gates.row(0);
float* gates_data_F = gates.row(1);
float* gates_data_O = gates.row(2);
float* gates_data_G = gates.row(3);
const float* bias_c_IFOG = (const float*)bias_c + q * 4;

// gate I F O G
const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

// float I = bias_c_I[q];
// float F = bias_c_F[q];
// float O = bias_c_O[q];
// float G = bias_c_G[q];
__m256 _sumI = _mm256_setzero_ps();
__m256 _sumF = _mm256_setzero_ps();
__m256 _sumO = _mm256_setzero_ps();
__m256 _sumG = _mm256_setzero_ps();
int nn_num_size = size >> 3;
int remain_size = size & 7;
for (; nn_num_size > 0; nn_num_size--)
#if __AVX__
const float* weight_xc_IFOG = weight_xc.row(q / 2 + q % 2);
const float* weight_hc_IFOG = weight_hc.row(q / 2 + q % 2);
#else
const float* weight_xc_IFOG = weight_xc.row(q);
const float* weight_hc_IFOG = weight_hc.row(q);
#endif

#if __SSE2__
__m128 _IFOG = _mm_loadu_ps(bias_c_IFOG);
__m128 _sum1 = _mm_setzero_ps();
__m128 _sum2 = _mm_setzero_ps();
__m128 _sum3 = _mm_setzero_ps();
#else // __SSE2__
float I = bias_c_IFOG[0];
float F = bias_c_IFOG[1];
float O = bias_c_IFOG[2];
float G = bias_c_IFOG[3];
#endif // __SSE2__

const float* x = bottom_blob.row(ti);

int i = 0;
#if __SSE2__
for (; i + 3 < size; i += 4)
{
__m256 xi = _mm256_loadu_ps(x);
_sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_I), xi, _sumI);
_sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_F), xi, _sumF);
_sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_O), xi, _sumO);
_sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_xc_G), xi, _sumG);
x += 8;
weight_xc_I += 8;
weight_xc_F += 8;
weight_xc_O += 8;
weight_xc_G += 8;
__m128 _xi0 = _mm_load1_ps(x);
__m128 _xi1 = _mm_load1_ps(x + 1);
__m128 _xi2 = _mm_load1_ps(x + 2);
__m128 _xi3 = _mm_load1_ps(x + 3);
__m128 _weight_xc_IFOG0 = _mm_loadu_ps(weight_xc_IFOG);
__m128 _weight_xc_IFOG1 = _mm_loadu_ps(weight_xc_IFOG + 4);
__m128 _weight_xc_IFOG2 = _mm_loadu_ps(weight_xc_IFOG + 8);
__m128 _weight_xc_IFOG3 = _mm_loadu_ps(weight_xc_IFOG + 12);
_IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG);
_sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1);
_sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2);
_sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3);

x += 4;
weight_xc_IFOG += 16;
}
int nn_num_output = num_output >> 3;
int remain_num_output = num_output & 7;
for (; nn_num_output > 0; nn_num_output--)
#endif // __SSE2__
for (; i < size; i++)
{
__m256 h_cont = _mm256_loadu_ps(hidden_ptr_r);

_sumI = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_I), h_cont, _sumI);
_sumF = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_F), h_cont, _sumF);
_sumO = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_O), h_cont, _sumO);
_sumG = _mm256_comp_fmadd_ps(_mm256_loadu_ps(weight_hc_G), h_cont, _sumG);
hidden_ptr_r += 8;
weight_hc_I += 8;
weight_hc_F += 8;
weight_hc_O += 8;
weight_hc_G += 8;
#if __SSE2__
__m128 _xi = _mm_load1_ps(x);
__m128 _weight_xc_IFOG = _mm_loadu_ps(weight_xc_IFOG);
_IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG);
#else // __SSE2__
float xi = x[0];
I += xi * weight_xc_IFOG[0];
F += xi * weight_xc_IFOG[1];
O += xi * weight_xc_IFOG[2];
G += xi * weight_xc_IFOG[3];
#endif // __SSE2__

x += 1;
weight_xc_IFOG += 4;
}
float sums[4];
_mm_storeu_ps(sums, HorizontalSums(_sumI, _sumF, _sumO, _sumG));
sums[0] += bias_c_I[q];
sums[1] += bias_c_F[q];
sums[2] += bias_c_O[q];
sums[3] += bias_c_G[q];

for (; remain_size > 0; remain_size--)

const float* hidden_ptr = hidden_state;

i = 0;
#if __SSE2__
for (; i + 3 < num_output; i += 4)
{
float xi = *x;
sums[0] += *weight_xc_I * xi;
sums[1] += *weight_xc_F * xi;
sums[2] += *weight_xc_O * xi;
sums[3] += *weight_xc_G * xi;
x++;
weight_xc_I++;
weight_xc_F++;
weight_xc_O++;
weight_xc_G++;
__m128 _h_cont0 = _mm_load1_ps(hidden_ptr);
__m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1);
__m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2);
__m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3);
__m128 _weight_hc_IFOG0 = _mm_loadu_ps(weight_hc_IFOG);
__m128 _weight_hc_IFOG1 = _mm_loadu_ps(weight_hc_IFOG + 4);
__m128 _weight_hc_IFOG2 = _mm_loadu_ps(weight_hc_IFOG + 8);
__m128 _weight_hc_IFOG3 = _mm_loadu_ps(weight_hc_IFOG + 12);
_IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG);
_sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1);
_sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2);
_sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3);

hidden_ptr += 4;
weight_hc_IFOG += 16;
}

for (; remain_num_output > 0; remain_num_output--)
#endif // __SSE2__
for (; i < num_output; i++)
{
float h_cont = *hidden_ptr_r;
sums[0] += *weight_hc_I * h_cont;
sums[1] += *weight_hc_F * h_cont;
sums[2] += *weight_hc_O * h_cont;
sums[3] += *weight_hc_G * h_cont;
hidden_ptr_r++;
weight_hc_I++;
weight_hc_F++;
weight_hc_O++;
weight_hc_G++;
#if __SSE2__
__m128 _h_cont = _mm_load1_ps(hidden_ptr);
__m128 _weight_hc_IFOG = _mm_loadu_ps(weight_hc_IFOG);
_IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG);
#else // __SSE2__
float h_cont = hidden_ptr[0];
I += h_cont * weight_hc_IFOG[0];
F += h_cont * weight_hc_IFOG[1];
O += h_cont * weight_hc_IFOG[2];
G += h_cont * weight_hc_IFOG[3];
#endif // __SSE2__

hidden_ptr += 1;
weight_hc_IFOG += 4;
}
gates_data_I[q] = sums[0];
gates_data_F[q] = sums[1];
gates_data_O[q] = sums[2];
gates_data_G[q] = sums[3];

float* gates_data = gates.row(q);

#if __SSE2__
_IFOG = _mm_add_ps(_IFOG, _sum1);
_sum2 = _mm_add_ps(_sum2, _sum3);
_IFOG = _mm_add_ps(_IFOG, _sum2);

_mm_storeu_ps(gates_data, _IFOG);
#else // __SSE2__
gates_data[0] = I;
gates_data[1] = F;
gates_data[2] = O;
gates_data[3] = G;
#endif // __SSE2__
}

// lstm unit
@@ -330,69 +452,117 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
// c_t := f_t .* c_{t-1} + i_t .* g_t
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
const float* gates_data_I = gates.row(0);
const float* gates_data_F = gates.row(1);
const float* gates_data_O = gates.row(2);
const float* gates_data_G = gates.row(3);
int nn_activation = num_output >> 3;
int remain_activations = num_output & 7;
for (; nn_activation > 0; nn_activation--)
float* tmp_hidden_ptr = tmp_hidden_state;
#if __SSE2__
nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;
#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
__m256 I = sigmoid_avx(_mm256_loadu_ps(gates_data_I));
__m256 F = sigmoid_avx(_mm256_loadu_ps(gates_data_F));
__m256 O = sigmoid_avx(_mm256_loadu_ps(gates_data_O));
__m256 G = tanh_avx(_mm256_loadu_ps(gates_data_G));
__m256 cell2 = _mm256_add_ps(_mm256_mul_ps(F, _mm256_loadu_ps(cell_ptr)), _mm256_mul_ps(I, G));
__m256 H = _mm256_mul_ps(O, tanh_avx(cell2));
_mm256_storeu_ps(cell_ptr, cell2);
_mm256_storeu_ps(hidden_ptr, H);
_mm256_storeu_ps(output_data, H);
cell_ptr += 8;
output_data += 8;
hidden_ptr += 8;
gates_data_I += 8;
gates_data_F += 8;
gates_data_O += 8;
gates_data_G += 8;
int q = qq * 4;

const float* gates_data = gates.row(q);

__m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data);
__m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4);
__m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8);
__m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12);

_MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3);

__m128 _I = sigmoid_sse(_IFOG_4x4_0);
__m128 _F = sigmoid_sse(_IFOG_4x4_1);
__m128 _O = sigmoid_sse(_IFOG_4x4_2);
__m128 _G = tanh_sse(_IFOG_4x4_3);

__m128 _cell2 = _mm_add_ps(_mm_mul_ps(_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_I, _G));
__m128 _H = _mm_mul_ps(_O, tanh_sse(_cell2));

_mm_storeu_ps(cell_ptr + q, _cell2);

if (num_output == hidden_size)
{
_mm_storeu_ps(hidden_ptr + q, _H);
_mm_storeu_ps(output_data + q, _H);
}
else
{
_mm_storeu_ps(tmp_hidden_ptr + q, _H);
}
}
for (; remain_activations > 0; remain_activations--)
#else // __SSE2__
remain_hidden_size_start = 0;
#endif // __SSE2__
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
float I = *gates_data_I;
float F = *gates_data_F;
float O = *gates_data_O;
float G = *gates_data_G;
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

I = 1.f / (1.f + exp(-I));
F = 1.f / (1.f + exp(-F));
O = 1.f / (1.f + exp(-O));
G = tanh(G);
float cell2 = F * *cell_ptr + I * G;

float cell2 = F * cell_ptr[q] + I * G;
float H = O * tanh(cell2);
*cell_ptr = cell2;
*hidden_ptr = H;
*output_data = H;
cell_ptr++;
output_data++;
hidden_ptr++;
gates_data_I++;
gates_data_F++;
gates_data_O++;
gates_data_G++;

cell_ptr[q] = cell2;
if (num_output == hidden_size)
{
hidden_ptr[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_ptr[q] = H;
}
}

// no cell output here
if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

output_data[q] = H;
hidden_ptr[q] = H;
}
}
}

return 0;
}
#endif

int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
#if __AVX__
int T = bottom_blob.h;

int num_directions = direction == 2 ? 2 : 1;

// initial hidden state
@@ -400,8 +570,8 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
if (hidden.empty())
return -100;
hidden.fill(0.f);
// internal cell state
Mat cell(num_output, 4u, opt.workspace_allocator);
Mat cell(hidden_size, 4u, opt.workspace_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -413,7 +583,7 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -428,14 +598,14 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;

@@ -452,14 +622,10 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
}

return 0;
#else
return LSTM::forward(bottom_blob, top_blob, opt);
#endif
}

int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
#if __AVX__
const Mat& bottom_blob = bottom_blobs[0];
int T = bottom_blob.h;
int num_directions = direction == 2 ? 2 : 1;
@@ -479,7 +645,7 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
return -100;
hidden.fill(0.f);

cell.create(num_output, num_directions, 4u, hidden_cell_allocator);
cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator);
if (cell.empty())
return -100;
cell.fill(0.f);
@@ -493,7 +659,7 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
@@ -510,15 +676,13 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;

@@ -541,9 +705,6 @@ int LSTM_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
}

return 0;
#else
return LSTM::forward(bottom_blobs, top_blobs, opt);
#endif
}

} // namespace ncnn

+ 3
- 0
src/layer/x86/lstm_x86.h View File

@@ -31,6 +31,9 @@ public:
virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;

public:
Mat weight_xc_data_packed;
Mat bias_c_data_packed;
Mat weight_hc_data_packed;
};

} // namespace ncnn


+ 70
- 42
tests/test_lstm.cpp View File

@@ -15,50 +15,64 @@
#include "layer/lstm.h"
#include "testutil.h"

static int test_lstm(const ncnn::Mat& a, int outch, int direction)
static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0)
{
int input_size = a.w;
int num_directions = direction == 2 ? 2 : 1;
if (hidden_size == 0)
hidden_size = outch;

ncnn::ParamDict pd;
pd.set(0, outch);
pd.set(1, outch * input_size * 4 * num_directions);
pd.set(1, hidden_size * input_size * 4 * num_directions);
pd.set(2, direction);
pd.set(3, hidden_size);

std::vector<ncnn::Mat> weights(3);
weights[0] = RandomMat(outch * input_size * 4 * num_directions);
weights[1] = RandomMat(outch * 4 * num_directions);
weights[2] = RandomMat(outch * outch * 4 * num_directions);
std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4);
weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions);
weights[1] = RandomMat(hidden_size * 4 * num_directions);
weights[2] = RandomMat(outch * hidden_size * 4 * num_directions);
if (hidden_size)
{
weights[3] = RandomMat(hidden_size * outch * num_directions);
}

int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction);
fprintf(stderr, "test_lstm failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size);
}

return ret;
}

int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction)
int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0)
{
int input_size = a.w;
int num_directions = direction == 2 ? 2 : 1;
if (hidden_size == 0)
hidden_size = outch;

ncnn::ParamDict pd;
pd.set(0, outch);
pd.set(1, outch * input_size * 4 * num_directions);
pd.set(1, hidden_size * input_size * 4 * num_directions);
pd.set(2, direction);
pd.set(3, hidden_size);

std::vector<ncnn::Mat> weights(3);
weights[0] = RandomMat(outch * input_size * 4 * num_directions);
weights[1] = RandomMat(outch * 4 * num_directions);
weights[2] = RandomMat(outch * outch * 4 * num_directions);
std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4);
weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions);
weights[1] = RandomMat(hidden_size * 4 * num_directions);
weights[2] = RandomMat(outch * hidden_size * 4 * num_directions);
if (hidden_size)
{
weights[3] = RandomMat(hidden_size * outch * num_directions);
}

// initial hidden state
ncnn::Mat hidden = RandomMat(outch, num_directions);

// initial cell state
ncnn::Mat cell = RandomMat(outch, num_directions);
ncnn::Mat cell = RandomMat(hidden_size, num_directions);

std::vector<ncnn::Mat> as(3);
as[0] = a;
@@ -68,32 +82,39 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction)
int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 3);
if (ret != 0)
{
fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction);
fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size);
}

return ret;
}

int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction)
int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0)
{
int input_size = a.w;
int num_directions = direction == 2 ? 2 : 1;
if (hidden_size == 0)
hidden_size = outch;

ncnn::ParamDict pd;
pd.set(0, outch);
pd.set(1, outch * input_size * 4 * num_directions);
pd.set(1, hidden_size * input_size * 4 * num_directions);
pd.set(2, direction);
pd.set(3, hidden_size);

std::vector<ncnn::Mat> weights(3);
weights[0] = RandomMat(outch * input_size * 4 * num_directions);
weights[1] = RandomMat(outch * 4 * num_directions);
weights[2] = RandomMat(outch * outch * 4 * num_directions);
std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4);
weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions);
weights[1] = RandomMat(hidden_size * 4 * num_directions);
weights[2] = RandomMat(outch * hidden_size * 4 * num_directions);
if (hidden_size)
{
weights[3] = RandomMat(hidden_size * outch * num_directions);
}

// initial hidden state
ncnn::Mat hidden = RandomMat(outch, num_directions);

// initial cell state
ncnn::Mat cell = RandomMat(outch, num_directions);
ncnn::Mat cell = RandomMat(hidden_size, num_directions);

std::vector<ncnn::Mat> as(3);
as[0] = a;
@@ -103,26 +124,33 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi
int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 1);
if (ret != 0)
{
fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction);
fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size);
}

return ret;
}

int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction)
int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0)
{
int input_size = a.w;
int num_directions = direction == 2 ? 2 : 1;
if (hidden_size == 0)
hidden_size = outch;

ncnn::ParamDict pd;
pd.set(0, outch);
pd.set(1, outch * input_size * 4 * num_directions);
pd.set(1, hidden_size * input_size * 4 * num_directions);
pd.set(2, direction);
pd.set(3, hidden_size);

std::vector<ncnn::Mat> weights(3);
weights[0] = RandomMat(outch * input_size * 4 * num_directions);
weights[1] = RandomMat(outch * 4 * num_directions);
weights[2] = RandomMat(outch * outch * 4 * num_directions);
std::vector<ncnn::Mat> weights(hidden_size == 0 ? 3 : 4);
weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions);
weights[1] = RandomMat(hidden_size * 4 * num_directions);
weights[2] = RandomMat(outch * hidden_size * 4 * num_directions);
if (hidden_size)
{
weights[3] = RandomMat(hidden_size * outch * num_directions);
}

std::vector<ncnn::Mat> as(1);
as[0] = a;
@@ -130,7 +158,7 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct
int ret = test_layer<ncnn::LSTM>("LSTM", pd, weights, as, 3);
if (ret != 0)
{
fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction);
fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size);
}

return ret;
@@ -147,7 +175,7 @@ static int test_lstm_0()
|| test_lstm(RandomMat(5, 16), 16, 2)
|| test_lstm(RandomMat(3, 16), 8, 2)
|| test_lstm(RandomMat(8, 16), 16, 2)
|| test_lstm(RandomMat(2, 5), 17, 2);
|| test_lstm(RandomMat(2, 5), 17, 2, 15);
}

static int test_lstm_1()
@@ -160,7 +188,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2)
|| test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2)
|| test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2, 33)
|| test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1)
|| test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1)
|| test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1)
@@ -168,7 +196,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 1)
|| test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 1)
|| test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 1)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1, 33)
|| test_lstm_layer_with_hidden(RandomMat(4, 2), 1, 0)
|| test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 0)
|| test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 0)
@@ -176,7 +204,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0)
|| test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0)
|| test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0)
|| test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0, 15)

|| test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2)
|| test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2)
@@ -185,7 +213,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2)
|| test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2)
|| test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2, 33)
|| test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1)
@@ -193,7 +221,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1, 33)
|| test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0)
@@ -201,7 +229,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0)
|| test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0, 15)

|| test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2)
|| test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2)
@@ -210,7 +238,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2)
|| test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2)
|| test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2)
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2)
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2, 33)
|| test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1)
@@ -218,7 +246,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1)
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1, 33)
|| test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0)
|| test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0)
|| test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0)
@@ -226,7 +254,7 @@ static int test_lstm_1()
|| test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0)
|| test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0)
|| test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0)
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0);
|| test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0, 15);
}

static int test_lstm_2()
@@ -240,7 +268,7 @@ static int test_lstm_2()
|| test_lstm(RandomMat(5, 16), 16, 0)
|| test_lstm(RandomMat(3, 16), 8, 0)
|| test_lstm(RandomMat(8, 16), 16, 0)
|| test_lstm(RandomMat(2, 5), 17, 0);
|| test_lstm(RandomMat(2, 5), 17, 0, 15);
}
static int test_lstm_3()
{
@@ -253,7 +281,7 @@ static int test_lstm_3()
|| test_lstm(RandomMat(5, 16), 16, 1)
|| test_lstm(RandomMat(3, 16), 8, 1)
|| test_lstm(RandomMat(8, 16), 16, 1)
|| test_lstm(RandomMat(2, 5), 17, 1);
|| test_lstm(RandomMat(2, 5), 17, 1, 15);
}

int main()


+ 24
- 7
tools/pnnx/src/pass_level1/nn_LSTM.cpp View File

@@ -33,9 +33,9 @@ public:

void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
{
// mod.dump(true, true, true);
// graph->dump();
// mod.dump(true, true, true);
//
// graph->dump();

const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm");

@@ -49,12 +49,13 @@ public:
op->params["pnnx_rnn_output_swapped"] = 1;
}

// for (auto aa : lstm->schema().arguments())
// {
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }
// for (auto aa : lstm->schema().arguments())
// {
// fprintf(stderr, "arg %s\n", aa.name().c_str());
// }

const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor();
const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor();

op->params["input_size"] = weight_ih_l0.size(1);
op->params["hidden_size"] = weight_ih_l0.size(0) / 4;
@@ -62,10 +63,12 @@ public:
op->params["bias"] = lstm->namedInput("has_biases");
op->params["batch_first"] = lstm->namedInput("batch_first");
op->params["bidirectional"] = lstm->namedInput("bidirectional");
op->params["proj_size"] = weight_ih_l0.size(0) / 4 == weight_hh_l0.size(1) ? 0 : weight_hh_l0.size(1);

const int num_layers = op->params["num_layers"].i;
const bool bias = op->params["bias"].b;
const bool bidirectional = op->params["bidirectional"].b;
const int proj_size = op->params["proj_size"].i;

for (int k = 0; k < num_layers; k++)
{
@@ -84,6 +87,13 @@ public:
op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor();
}

if (proj_size > 0)
{
std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k);

op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor();
}

if (bidirectional)
{
std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse";
@@ -100,6 +110,13 @@ public:
op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor();
op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor();
}

if (proj_size > 0)
{
std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse";

op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor();
}
}
}
}


+ 19
- 1
tools/pnnx/src/pass_level5/unroll_rnn_op.cpp View File

@@ -42,6 +42,7 @@ void unroll_rnn_op(Graph& graph)
bool has_output_hidden = op->outputs.size() >= 2;
bool has_output_cell = op->outputs.size() == 3;
const int hidden_size = op->params["hidden_size"].i;
const int proj_size = (op->type == "nn.LSTM") ? op->params["proj_size"].i : 0;
bool has_bias = op->params["bias"].b;
bool is_bidirectional = op->params["bidirectional"].b;

@@ -116,7 +117,14 @@ void unroll_rnn_op(Graph& graph)
}
else
{
op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size;
if (proj_size)
{
op1->params["input_size"] = is_bidirectional ? proj_size * 2 : proj_size;
}
else
{
op1->params["input_size"] = is_bidirectional ? hidden_size * 2 : hidden_size;
}

op1->inputs.push_back(unrolled_ops[j - 1]->outputs[0]);
op1->inputs[0]->consumers.push_back(op1);
@@ -171,6 +179,11 @@ void unroll_rnn_op(Graph& graph)
op1->attrs["bias_ih_l0"] = op->attrs["bias_ih_l" + std::to_string(j)];
}

if (proj_size)
{
op1->attrs["weight_hr_l0"] = op->attrs["weight_hr_l" + std::to_string(j)];
}

if (is_bidirectional)
{
op1->attrs["weight_hh_l0_reverse"] = op->attrs["weight_hh_l" + std::to_string(j) + "_reverse"];
@@ -181,6 +194,11 @@ void unroll_rnn_op(Graph& graph)
op1->attrs["bias_hh_l0_reverse"] = op->attrs["bias_hh_l" + std::to_string(j) + "_reverse"];
op1->attrs["bias_ih_l0_reverse"] = op->attrs["bias_ih_l" + std::to_string(j) + "_reverse"];
}

if (proj_size)
{
op1->attrs["weight_hr_l0_reverse"] = op->attrs["weight_hr_l" + std::to_string(j) + "_reverse"];
}
}

unrolled_ops[j] = op1;


+ 73
- 53
tools/pnnx/src/pass_ncnn/nn_LSTM.cpp View File

@@ -27,7 +27,7 @@ public:
return R"PNNXIR(7767517
3 4
pnnx.Input input 0 1 input
nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
nn.LSTM op_0 1 3 input out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse
pnnx.Output output 3 0 out out_hidden out_cell
)PNNXIR";
}
@@ -46,14 +46,19 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
const bool bidirectional = captured_params.at("bidirectional").b;
const int num_directions = bidirectional ? 2 : 1;
const int num_output = captured_params.at("hidden_size").i;
const int hidden_size = captured_params.at("hidden_size").i;
const int input_size = captured_params.at("input_size").i;

int weight_data_size = num_directions * num_output * input_size * 4;
int proj_size = captured_params.at("proj_size").i;
if (proj_size == 0)
proj_size = hidden_size;

op->params["0"] = num_output;
int weight_data_size = num_directions * hidden_size * input_size * 4;

op->params["0"] = proj_size;
op->params["1"] = weight_data_size;
op->params["2"] = bidirectional ? 2 : 0;
op->params["3"] = hidden_size;

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
@@ -62,7 +67,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
std::vector<float> new_weight_ih;
{
const int weight_data_size_g = num_output * input_size;
const int weight_data_size_g = hidden_size * input_size;

const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0").data.data();
const float* iptr = weight_ih;
@@ -70,7 +75,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* gptr = weight_ih + weight_data_size_g * 2;
const float* optr = weight_ih + weight_data_size_g * 3;

new_weight_ih.resize(4 * num_output * input_size);
new_weight_ih.resize(4 * hidden_size * input_size);
float* weight = (float*)new_weight_ih.data();
float* w_iptr = weight;
float* w_fptr = weight + weight_data_size_g;
@@ -86,7 +91,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
std::vector<float> new_weight_ih_reverse;
{
const int weight_data_size_g = num_output * input_size;
const int weight_data_size_g = hidden_size * input_size;

const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0_reverse").data.data();
const float* iptr = weight_ih;
@@ -94,7 +99,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* gptr = weight_ih + weight_data_size_g * 2;
const float* optr = weight_ih + weight_data_size_g * 3;

new_weight_ih_reverse.resize(4 * num_output * input_size);
new_weight_ih_reverse.resize(4 * hidden_size * input_size);
float* weight = (float*)new_weight_ih_reverse.data();
float* w_iptr = weight;
float* w_fptr = weight + weight_data_size_g;
@@ -105,11 +110,11 @@ pnnx.Output output 3 0 out out_hidden out_cell
memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
}
op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih) + Attribute({4, num_output, input_size}, new_weight_ih_reverse);
op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih) + Attribute({4, hidden_size, input_size}, new_weight_ih_reverse);
}
else
{
op->attrs["1"] = Attribute({4, num_output, input_size}, new_weight_ih);
op->attrs["1"] = Attribute({4, hidden_size, input_size}, new_weight_ih);
}
}

@@ -124,33 +129,33 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data();
const float* bias_ih_iptr = bias_ih;
const float* bias_ih_fptr = bias_ih + num_output;
const float* bias_ih_gptr = bias_ih + num_output * 2;
const float* bias_ih_optr = bias_ih + num_output * 3;
const float* bias_ih_fptr = bias_ih + hidden_size;
const float* bias_ih_gptr = bias_ih + hidden_size * 2;
const float* bias_ih_optr = bias_ih + hidden_size * 3;
const float* bias_hh_iptr = bias_hh;
const float* bias_hh_fptr = bias_hh + num_output;
const float* bias_hh_gptr = bias_hh + num_output * 2;
const float* bias_hh_optr = bias_hh + num_output * 3;
const float* bias_hh_fptr = bias_hh + hidden_size;
const float* bias_hh_gptr = bias_hh + hidden_size * 2;
const float* bias_hh_optr = bias_hh + hidden_size * 3;

new_bias.resize(4 * num_output);
new_bias.resize(4 * hidden_size);
float* bias = (float*)new_bias.data();
float* b_iptr = bias;
float* b_fptr = bias + num_output;
float* b_optr = bias + num_output * 2;
float* b_gptr = bias + num_output * 3;
for (int i = 0; i < num_output; i++)
float* b_fptr = bias + hidden_size;
float* b_optr = bias + hidden_size * 2;
float* b_gptr = bias + hidden_size * 3;
for (int i = 0; i < hidden_size; i++)
{
b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i];
}
@@ -163,63 +168,63 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data();
const float* bias_ih_iptr = bias_ih;
const float* bias_ih_fptr = bias_ih + num_output;
const float* bias_ih_gptr = bias_ih + num_output * 2;
const float* bias_ih_optr = bias_ih + num_output * 3;
const float* bias_ih_fptr = bias_ih + hidden_size;
const float* bias_ih_gptr = bias_ih + hidden_size * 2;
const float* bias_ih_optr = bias_ih + hidden_size * 3;
const float* bias_hh_iptr = bias_hh;
const float* bias_hh_fptr = bias_hh + num_output;
const float* bias_hh_gptr = bias_hh + num_output * 2;
const float* bias_hh_optr = bias_hh + num_output * 3;
const float* bias_hh_fptr = bias_hh + hidden_size;
const float* bias_hh_gptr = bias_hh + hidden_size * 2;
const float* bias_hh_optr = bias_hh + hidden_size * 3;

new_bias_reverse.resize(4 * num_output);
new_bias_reverse.resize(4 * hidden_size);
float* bias = (float*)new_bias_reverse.data();
float* b_iptr = bias;
float* b_fptr = bias + num_output;
float* b_optr = bias + num_output * 2;
float* b_gptr = bias + num_output * 3;
for (int i = 0; i < num_output; i++)
float* b_fptr = bias + hidden_size;
float* b_optr = bias + hidden_size * 2;
float* b_gptr = bias + hidden_size * 3;
for (int i = 0; i < hidden_size; i++)
{
b_iptr[i] = bias_ih_iptr[i] + bias_hh_iptr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_fptr[i] = bias_ih_fptr[i] + bias_hh_fptr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_optr[i] = bias_ih_optr[i] + bias_hh_optr[i];
}
for (int i = 0; i < num_output; i++)
for (int i = 0; i < hidden_size; i++)
{
b_gptr[i] = bias_ih_gptr[i] + bias_hh_gptr[i];
}
}

op->attrs["3"] = Attribute({4, num_output}, new_bias) + Attribute({4, num_output}, new_bias_reverse);
op->attrs["3"] = Attribute({4, hidden_size}, new_bias) + Attribute({4, hidden_size}, new_bias_reverse);
}
else
{
op->attrs["3"] = Attribute({4, num_output}, new_bias);
op->attrs["3"] = Attribute({4, hidden_size}, new_bias);
}
}
else
{
std::vector<float> bias(4 * num_output, 0.f);
std::vector<float> bias(4 * hidden_size, 0.f);

if (bidirectional)
op->attrs["3"] = Attribute({4, num_output}, bias) + Attribute({4, num_output}, bias);
op->attrs["3"] = Attribute({4, hidden_size}, bias) + Attribute({4, hidden_size}, bias);
else
op->attrs["3"] = Attribute({4, num_output}, bias);
op->attrs["3"] = Attribute({4, hidden_size}, bias);
}

op->attrs["4"] = Attribute();
op->attrs["4"].data = {0, 0, 0, 0};

// reorder IFGO-hidden-hidden to IFOG-hidden-hidden
// reorder IFGO-hidden-proj to IFOG-hidden-proj
{
std::vector<float> new_weight_hh;
{
const int weight_data_size_g = num_output * num_output;
const int weight_data_size_g = hidden_size * proj_size;

const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0").data.data();
const float* iptr = weight_hh;
@@ -227,7 +232,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* gptr = weight_hh + weight_data_size_g * 2;
const float* optr = weight_hh + weight_data_size_g * 3;

new_weight_hh.resize(4 * num_output * num_output);
new_weight_hh.resize(4 * hidden_size * proj_size);
float* weight = (float*)new_weight_hh.data();
float* w_iptr = weight;
float* w_fptr = weight + weight_data_size_g;
@@ -243,7 +248,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
std::vector<float> new_weight_hh_reverse;
{
const int weight_data_size_g = num_output * num_output;
const int weight_data_size_g = hidden_size * proj_size;

const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0_reverse").data.data();
const float* iptr = weight_hh;
@@ -251,7 +256,7 @@ pnnx.Output output 3 0 out out_hidden out_cell
const float* gptr = weight_hh + weight_data_size_g * 2;
const float* optr = weight_hh + weight_data_size_g * 3;

new_weight_hh_reverse.resize(4 * num_output * num_output);
new_weight_hh_reverse.resize(4 * hidden_size * proj_size);
float* weight = (float*)new_weight_hh_reverse.data();
float* w_iptr = weight;
float* w_fptr = weight + weight_data_size_g;
@@ -262,11 +267,26 @@ pnnx.Output output 3 0 out out_hidden out_cell
memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
}
op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh) + Attribute({4, num_output, num_output}, new_weight_hh_reverse);
op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh) + Attribute({4, hidden_size, proj_size}, new_weight_hh_reverse);
}
else
{
op->attrs["5"] = Attribute({4, hidden_size, proj_size}, new_weight_hh);
}
}

if (proj_size != hidden_size)
{
op->attrs["6"] = Attribute();
op->attrs["6"].data = {0, 0, 0, 0};

if (bidirectional)
{
op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0") + captured_attrs.at("op_0.weight_hr_l0_reverse");
}
else
{
op->attrs["5"] = Attribute({4, num_output, num_output}, new_weight_hh);
op->attrs["7"] = captured_attrs.at("op_0.weight_hr_l0");
}
}
}
@@ -284,7 +304,7 @@ public:
pnnx.Input input 0 1 input
pnnx.Input in_hidden 0 1 in_hidden
pnnx.Input in_hidden 0 1 in_cell
nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
nn.LSTM op_0 3 3 input in_hidden in_cell out out_hidden out_cell input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse
pnnx.Output output 3 0 out out_hidden out_cell
)PNNXIR";
}
@@ -300,7 +320,7 @@ public:
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
nn.LSTM op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -318,7 +338,7 @@ public:
pnnx.Input input 0 1 input
pnnx.Input in_hidden 0 1 in_hidden
pnnx.Input in_hidden 0 1 in_cell
nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
nn.LSTM op_0 3 1 input in_hidden in_cell out input_size=%input_size hidden_size=%hidden_size num_layers=1 bias=%bias batch_first=%batch_first bidirectional=%bidirectional proj_size=%proj_size @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_hr_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse @weight_hr_l0_reverse
pnnx.Output output 1 0 out
)PNNXIR";
}


+ 10
- 10
tools/pnnx/tests/ncnn/test_nn_LSTM.py View File

@@ -22,15 +22,15 @@ class Model(nn.Module):

self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16)
self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False)
self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
self.lstm_0_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10)
self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10)
self.lstm_0_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10)

self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True)
self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True)
self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
self.lstm_1_4 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10)
self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10)
self.lstm_1_4 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10)

def forward(self, x, y):
x = x.permute(1, 0, 2)
@@ -38,14 +38,14 @@ class Model(nn.Module):
x0, _ = self.lstm_0_0(x)
x1, _ = self.lstm_0_1(x0)
x2, (h0, c0) = self.lstm_0_2(x1)
x3, (h1, c1) = self.lstm_0_3(x1, (h0, c0))
x4, _ = self.lstm_0_4(x1, (h1, c1))
x3, (h1, c1) = self.lstm_0_3(x2, (h0, c0))
x4, _ = self.lstm_0_4(x3, (h1, c1))

y0, _ = self.lstm_1_0(y)
y1, _ = self.lstm_1_1(y0)
y2, (h2, c2) = self.lstm_1_2(y1)
y3, (h3, c3) = self.lstm_1_3(y1, (h2, c2))
y4, _ = self.lstm_1_4(y1, (h3, c3))
y3, (h3, c3) = self.lstm_1_3(y2, (h2, c2))
y4, _ = self.lstm_1_4(y3, (h3, c3))

x2 = x2.permute(1, 0, 2)
x3 = x3.permute(1, 0, 2)


+ 6
- 6
tools/pnnx/tests/test_nn_LSTM.py View File

@@ -22,24 +22,24 @@ class Model(nn.Module):

self.lstm_0_0 = nn.LSTM(input_size=32, hidden_size=16)
self.lstm_0_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False)
self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
self.lstm_0_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True)
self.lstm_0_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10)
self.lstm_0_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, bidirectional=True, proj_size=10)

self.lstm_1_0 = nn.LSTM(input_size=25, hidden_size=16, batch_first=True)
self.lstm_1_1 = nn.LSTM(input_size=16, hidden_size=16, num_layers=3, bias=False, batch_first=True)
self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
self.lstm_1_3 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True)
self.lstm_1_2 = nn.LSTM(input_size=16, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10)
self.lstm_1_3 = nn.LSTM(input_size=20, hidden_size=16, num_layers=4, bias=True, batch_first=True, bidirectional=True, proj_size=10)

def forward(self, x, y):
x0, (h0, c0) = self.lstm_0_0(x)
x1, (h1, c1) = self.lstm_0_1(x0)
x2, (h2, c2) = self.lstm_0_2(x1)
x3, (h3, c3) = self.lstm_0_3(x1, (h2, c2))
x3, (h3, c3) = self.lstm_0_3(x2, (h2, c2))

y0, (h4, c4) = self.lstm_1_0(y)
y1, (h5, c5) = self.lstm_1_1(y0)
y2, (h6, c6) = self.lstm_1_2(y1)
y3, (h7, c7) = self.lstm_1_3(y1, (h6, c6))
y3, (h7, c7) = self.lstm_1_3(y2, (h6, c6))
return x2, x3, h0, h1, h2, h3, c0, c1, c2, c3, y2, y3, h4, h5, h6, h7, c4, c5, c6, c7

def test():


Loading…
Cancel
Save