diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index 97a2bb61e..a49147813 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -695,13 +695,7 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 2 || top_blobs.size() != 2) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } - const Mat& bottom_blob = bottom_blobs[0]; - int elembits = bottom_blob.elembits(); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -720,24 +714,72 @@ int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top #endif int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, opt); + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + return 0; } @@ -1625,16 +1667,29 @@ int GRU_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } return 0; } @@ -1711,16 +1801,29 @@ int GRU_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } return 0; } @@ -2365,16 +2503,29 @@ int GRU_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const unsigned short* pr = top_blob_reverse.row(i); + unsigned short* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(unsigned short)); + memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + } return 0; } diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 7b8511b8f..185b8ff54 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -423,13 +423,7 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 3 || top_blobs.size() != 3) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } - const Mat& bottom_blob = bottom_blobs[0]; - int elembits = bottom_blob.elembits(); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -448,26 +442,82 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to #endif int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; - Mat& cell_state = top_blobs[2]; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); - cell_state = bottom_blobs[2].clone(opt.blob_allocator); + cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, cell_state, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, cell, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, cell0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, cell1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + return 0; } @@ -1182,17 +1232,35 @@ int LSTM_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 3) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + cast_float32_to_float16(cell, top_blobs[2], opt); + } return 0; } @@ -1277,17 +1382,35 @@ int LSTM_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 3) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + cast_float32_to_float16(cell, top_blobs[2], opt); + } return 0; } @@ -1664,17 +1824,35 @@ int LSTM_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const unsigned short* pr = top_blob_reverse.row(i); + unsigned short* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(unsigned short)); + memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); + } + } + + if (top_blobs.size() == 3) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + cast_float32_to_bfloat16(cell, top_blobs[2], opt); + } return 0; } diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp index 9bd5bcc29..c113eb282 100644 --- a/src/layer/arm/rnn_arm.cpp +++ b/src/layer/arm/rnn_arm.cpp @@ -377,13 +377,7 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 2 || top_blobs.size() != 2) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } - const Mat& bottom_blob = bottom_blobs[0]; - int elembits = bottom_blob.elembits(); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -402,24 +396,72 @@ int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top #endif int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden_state, opt); + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + return 0; } @@ -965,16 +1007,29 @@ int RNN_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } return 0; } @@ -1051,16 +1141,29 @@ int RNN_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } return 0; } @@ -1387,16 +1525,29 @@ int RNN_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(i); + const unsigned short* pr = top_blob_reverse.row(i); + unsigned short* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(unsigned short)); + memcpy(ptr + num_output, pr, num_output * sizeof(unsigned short)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + } return 0; } diff --git a/src/layer/gru.cpp b/src/layer/gru.cpp index 6183f4f0a..5690252ad 100644 --- a/src/layer/gru.cpp +++ b/src/layer/gru.cpp @@ -29,8 +29,6 @@ int GRU::load_param(const ParamDict& pd) num_output = pd.get(0, 0); weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); - if (direction == 2) - one_blob_only = true; return 0; } @@ -223,30 +221,74 @@ int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const int GRU::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 2 || top_blobs.size() != 2) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, opt); + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + return 0; } diff --git a/src/layer/lstm.cpp b/src/layer/lstm.cpp index ad7720085..7907e1dc0 100644 --- a/src/layer/lstm.cpp +++ b/src/layer/lstm.cpp @@ -29,8 +29,6 @@ int LSTM::load_param(const ParamDict& pd) num_output = pd.get(0, 0); weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); - if (direction == 2) - one_blob_only = true; return 0; } @@ -232,32 +230,84 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 3 || top_blobs.size() != 3) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; - Mat& cell_state = top_blobs[2]; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); - cell_state = bottom_blobs[2].clone(opt.blob_allocator); + cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt); + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + return 0; } diff --git a/src/layer/riscv/gru_riscv.cpp b/src/layer/riscv/gru_riscv.cpp index 0ebb2879b..20666f5ee 100644 --- a/src/layer/riscv/gru_riscv.cpp +++ b/src/layer/riscv/gru_riscv.cpp @@ -301,11 +301,6 @@ int GRU_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int GRU_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 2 || top_blobs.size() != 2) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } - const Mat& bottom_blob = bottom_blobs[0]; int elembits = bottom_blob.elembits(); @@ -321,24 +316,73 @@ int GRU_riscv::forward(const std::vector& bottom_blobs, std::vector& t #endif int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, opt); + Mat hidden0 = hidden.row_range(0, 1); + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + return 0; #endif return GRU::forward(bottom_blobs, top_blobs, opt); @@ -587,24 +631,75 @@ int GRU_riscv::forward_fp16s(const std::vector& bottom_blobs, std::vector(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } + return 0; } @@ -853,15 +948,29 @@ int GRU_riscv::forward_fp16sa(const std::vector& bottom_blobs, std::vector< { const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } + Mat& top_blob = top_blobs[0]; - top_blob.create(num_output, T, 2u, opt.blob_allocator); + top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); if (top_blob.empty()) return -100; - //Copy previous states - Mat hidden; - cast_float16_to_float32(bottom_blobs[1], hidden, opt); - // Uni directional if (direction == 0 || direction == 1) { @@ -870,11 +979,46 @@ int GRU_riscv::forward_fp16sa(const std::vector& bottom_blobs, std::vector< return ret; } - cast_float32_to_float16(hidden, top_blobs[1], opt); + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = gru_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_fp16sa.channel(0), bias_c_data_fp16sa.channel(0), weight_hc_data_fp16sa.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = gru_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_fp16sa.channel(1), bias_c_data_fp16sa.channel(1), weight_hc_data_fp16sa.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const __fp16* pf = top_blob_forward.row(i); + const __fp16* pr = top_blob_reverse.row(i); + __fp16* ptr = top_blob.row<__fp16>(i); + + memcpy(ptr, pf, num_output * sizeof(__fp16)); + memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + } + } + + if (top_blobs.size() == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } return 0; } #endif -} // namespace ncnn \ No newline at end of file +} // namespace ncnn diff --git a/src/layer/rnn.cpp b/src/layer/rnn.cpp index efe9b8461..54627d5c7 100644 --- a/src/layer/rnn.cpp +++ b/src/layer/rnn.cpp @@ -29,8 +29,6 @@ int RNN::load_param(const ParamDict& pd) num_output = pd.get(0, 0); weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); - if (direction == 2) - one_blob_only = true; return 0; } @@ -172,30 +170,74 @@ int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const int RNN::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - if (bottom_blobs.size() != 2 || top_blobs.size() != 2) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } - top_blob.create(num_output, T, 4u, opt.blob_allocator); + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, opt); + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); if (ret != 0) return ret; } + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret0 != 0) + return ret0; + + Mat hidden1 = hidden.row_range(1, 1); + int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret1 != 0) + return ret1; + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 2) + { + top_blobs[1] = hidden; + } + return 0; } diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 0c8d42d13..6e55fd700 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -910,42 +910,123 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { #if __AVX__ - if (bottom_blobs.size() != 3 || top_blobs.size() != 3) - { - return forward(bottom_blobs[0], top_blobs[0], opt); - } const Mat& bottom_blob = bottom_blobs[0]; - int T = bottom_blob.h; - Mat& top_blob = top_blobs[0]; - Mat& hidden_state = top_blobs[1]; - Mat& cell_state = top_blobs[2]; + int num_directions = direction == 2 ? 2 : 1; - //Copy previous states - hidden_state = bottom_blobs[1].clone(opt.blob_allocator); - cell_state = bottom_blobs[2].clone(opt.blob_allocator); + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); - top_blob.create(num_output, T, 4u, opt.blob_allocator); + cell.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); if (top_blob.empty()) return -100; -#if __AVX2__ - if (opt.use_weight_fp16_storage) + + // Uni directional + if (direction == 0 || direction == 1) { - // Uni directional - int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden_state, cell_state, opt); - if (ret != 0) - return ret; +#if __AVX2__ + if (opt.use_weight_fp16_storage) + { + int ret = lstm_fp16(bottom_blob, top_blob, direction, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else + { +#endif + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; +#if __AVX2__ + } +#endif } - else + + if (direction == 2) { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); +#if __AVX2__ + if (opt.use_weight_fp16_storage) + { + int ret = lstm_fp16(bottom_blob, top_blob_forward, 0, weight_xc_data_fp16.channel(0), bias_c_data.channel(0), weight_hc_data_fp16.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else + { #endif - // Uni directional - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt); - if (ret != 0) - return ret; + int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, cell0, opt); + if (ret0 != 0) + return ret0; #if __AVX2__ - } + } #endif + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); +#if __AVX2__ + if (opt.use_weight_fp16_storage) + { + int ret = lstm_fp16(bottom_blob, top_blob_reverse, 1, weight_xc_data_fp16.channel(1), bias_c_data.channel(1), weight_hc_data_fp16.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else + { +#endif + int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, cell1, opt); + if (ret1 != 0) + return ret1; +#if __AVX2__ + } +#endif + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + return 0; #else return LSTM::forward(bottom_blobs, top_blobs, opt); diff --git a/tests/test_gru.cpp b/tests/test_gru.cpp index af3b178bb..006e54487 100644 --- a/tests/test_gru.cpp +++ b/tests/test_gru.cpp @@ -42,19 +42,20 @@ static int test_gru(const ncnn::Mat& a, int outch, int direction) int test_gru_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 3); + pd.set(1, outch * input_size * 3 * num_directions); pd.set(2, direction); std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 3); - weights[1] = RandomMat(outch * 4); - weights[2] = RandomMat(outch * outch * 3); + weights[0] = RandomMat(outch * input_size * 3 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 3 * num_directions); // initial hidden state - ncnn::Mat hidden = RandomMat(outch); + ncnn::Mat hidden = RandomMat(outch, num_directions); std::vector as(2); as[0] = a; @@ -69,6 +70,64 @@ int test_gru_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) return ret; } +int test_gru_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 3 * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * 3 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 3 * num_directions); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("GRU", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_gru_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_gru_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 3 * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * 3 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 3 * num_directions); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("GRU", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_gru_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + static int test_gru_0() { return 0 @@ -86,6 +145,14 @@ static int test_gru_0() static int test_gru_1() { return 0 + || test_gru_layer_with_hidden(RandomMat(4, 4), 1, 2) + || test_gru_layer_with_hidden(RandomMat(8, 2), 2, 2) + || test_gru_layer_with_hidden(RandomMat(16, 8), 7, 2) + || test_gru_layer_with_hidden(RandomMat(17, 8), 8, 2) + || test_gru_layer_with_hidden(RandomMat(19, 15), 8, 2) + || test_gru_layer_with_hidden(RandomMat(5, 16), 16, 2) + || test_gru_layer_with_hidden(RandomMat(3, 16), 8, 2) + || test_gru_layer_with_hidden(RandomMat(2, 5), 99, 2) || test_gru_layer_with_hidden(RandomMat(4, 4), 1, 1) || test_gru_layer_with_hidden(RandomMat(8, 2), 2, 1) || test_gru_layer_with_hidden(RandomMat(16, 8), 7, 1) @@ -101,7 +168,57 @@ static int test_gru_1() || test_gru_layer_with_hidden(RandomMat(19, 15), 8, 0) || test_gru_layer_with_hidden(RandomMat(5, 16), 16, 0) || test_gru_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_gru_layer_with_hidden(RandomMat(2, 5), 17, 0); + || test_gru_layer_with_hidden(RandomMat(2, 5), 17, 0) + + || test_gru_layer_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_gru_layer_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_gru_layer_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_gru_layer_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_gru_layer_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_gru_layer_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_gru_layer_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_gru_layer_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_gru_layer_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_gru_layer_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_gru_layer_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_gru_layer_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_gru_2() diff --git a/tests/test_lstm.cpp b/tests/test_lstm.cpp index 9b337d685..f002a1aec 100644 --- a/tests/test_lstm.cpp +++ b/tests/test_lstm.cpp @@ -42,22 +42,23 @@ static int test_lstm(const ncnn::Mat& a, int outch, int direction) int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size * 4); + pd.set(1, outch * input_size * 4 * num_directions); pd.set(2, direction); std::vector weights(3); - weights[0] = RandomMat(outch * input_size * 4); - weights[1] = RandomMat(outch * 4); - weights[2] = RandomMat(outch * outch * 4); + weights[0] = RandomMat(outch * input_size * 4 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 4 * num_directions); // initial hidden state - ncnn::Mat hidden = RandomMat(outch); + ncnn::Mat hidden = RandomMat(outch, num_directions); // initial cell state - ncnn::Mat cell = RandomMat(outch); + ncnn::Mat cell = RandomMat(outch, num_directions); std::vector as(3); as[0] = a; @@ -73,6 +74,68 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) return ret; } +int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 4 * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * 4 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 4 * num_directions); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + // initial cell state + ncnn::Mat cell = RandomMat(outch, num_directions); + + std::vector as(3); + as[0] = a; + as[1] = hidden; + as[2] = cell; + + int ret = test_layer("LSTM", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 4 * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * 4 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomMat(outch * outch * 4 * num_directions); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("LSTM", pd, weights, as, 3); + if (ret != 0) + { + fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + static int test_lstm_0() { return 0 @@ -90,6 +153,14 @@ static int test_lstm_0() static int test_lstm_1() { return 0 + || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 2) + || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 2) + || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 2) + || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 2) + || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2) + || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2) + || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2) + || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2) || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1) || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1) || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1) @@ -105,7 +176,57 @@ static int test_lstm_1() || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0) || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0) || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0); + || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0) + + || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_lstm_2() diff --git a/tests/test_rnn.cpp b/tests/test_rnn.cpp index 0be24f9e6..31b89a22a 100644 --- a/tests/test_rnn.cpp +++ b/tests/test_rnn.cpp @@ -42,19 +42,20 @@ static int test_rnn(const ncnn::Mat& a, int outch, int direction) int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; ncnn::ParamDict pd; pd.set(0, outch); - pd.set(1, outch * input_size); + pd.set(1, outch * input_size * num_directions); pd.set(2, direction); std::vector weights(3); - weights[0] = RandomMat(outch * input_size); - weights[1] = RandomMat(outch); - weights[2] = RandomMat(outch * outch); + weights[0] = RandomMat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomMat(outch * outch * num_directions); // initial hidden state - ncnn::Mat hidden = RandomMat(outch); + ncnn::Mat hidden = RandomMat(outch, num_directions); std::vector as(2); as[0] = a; @@ -69,6 +70,64 @@ int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) return ret; } +int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomMat(outch * outch * num_directions); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("RNN", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_rnn_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + + std::vector weights(3); + weights[0] = RandomMat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomMat(outch * outch * num_directions); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("RNN", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_rnn_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + static int test_rnn_0() { return 0 @@ -86,6 +145,14 @@ static int test_rnn_0() static int test_rnn_1() { return 0 + || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 2) + || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 2) + || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 2) + || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 2) + || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 2) + || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 2) + || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 2) + || test_rnn_layer_with_hidden(RandomMat(2, 5), 99, 2) || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 1) || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 1) || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 1) @@ -101,7 +168,57 @@ static int test_rnn_1() || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 0) || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 0) || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 17, 0); + || test_rnn_layer_with_hidden(RandomMat(2, 5), 17, 0) + + || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 2) + || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 1) + || test_rnn_layer_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 2) + || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 1) + || test_rnn_layer_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_rnn_2()