|
|
|
|
|
|
#include <float.h> |
|
|
|
#include "nnacl/fp32/activation_fp32.h" |
|
|
|
#include "nnacl/fp32/arithmetic_fp32.h" |
|
|
|
#include "nnacl/fp32/mul_fp32.h" |
|
|
|
|
|
|
|
void InitGate(float *gate_buffer, const float *bias, const LstmParameter *lstm_parm) { |
|
|
|
int gate_offest = 0; |
|
|
|
for (int l = 0; l < 4; l++) { |
|
|
|
int batch_offest = gate_offest; |
|
|
|
int bias_offest = l * lstm_parm->hidden_size_; |
|
|
|
for (int b = 0; b < lstm_parm->batch_; b++) { |
|
|
|
memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float)); |
|
|
|
batch_offest += lstm_parm->hidden_size_; |
|
|
|
} |
|
|
|
gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; |
|
|
|
} |
|
|
|
} |
|
|
|
#include "nnacl/fp32/matmul_fp32.h" |
|
|
|
|
|
|
|
// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] |
|
|
|
void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { |
|
|
|
@@ -134,106 +121,131 @@ void UpdataOutput(const float *cell_state, const float *output_gate, float *hidd |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, |
|
|
|
const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, |
|
|
|
const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, |
|
|
|
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, |
|
|
|
const LstmParameter *lstm_parm) { |
|
|
|
InitGate(gate_buffer, bias, lstm_parm); |
|
|
|
void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { |
|
|
|
if (is_vec) { |
|
|
|
memcpy(c, bias, col * sizeof(float)); |
|
|
|
MatMulAcc(c, a, b, row, col, deep); |
|
|
|
} else { |
|
|
|
MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Packs a row-major [row, deep] activation matrix into the column-tiled
// layout consumed by the platform's matmul kernel.  The tile width
// (6 for AVX, 4 for SSE, 12 otherwise) must match the GEMM variant the
// build selects — NOTE(review): presumably kept in sync with MatMulOpt's
// expected packing; confirm against matmul_fp32.h.
void PackLstmInput(float *dst, const float *src, int row, int deep) {
#ifdef ENABLE_AVX
  RowMajor2Col6Major(src, dst, row, deep);
#elif defined(ENABLE_SSE)
  RowMajor2Col4Major(src, dst, row, deep);
#else
  RowMajor2Col12Major(src, dst, row, deep);
#endif
}
|
|
|
|
|
|
|
// Computes the four LSTM gate pre-activations for one operand:
//   gate_buffer[g] = input * W[g]^T + b[g]  for g in {input, output, forget, cell}.
// weight holds four [col, deep] matrices in I, O, F, C order; bias holds four
// vectors strided by col_align in the same order; gate_buffer holds four
// [row, col] result sections in the same order.
// is_vec selects the unpacked batch==1 matmul path (see LstmMatmul).
void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep,
                int col, int col_align, bool is_vec) {
  // Weight sections, in the I, O, F, C storage order.
  const float *input_weight = weight;
  const float *forget_weight = weight + deep * col * 2;
  const float *cell_weight = weight + deep * col * 3;
  const float *output_weight = weight + deep * col;

  // Bias sections use col_align stride (padded for the packed kernel).
  const float *input_bias = bias;
  const float *forget_bias = bias + col_align * 2;
  const float *cell_bias = bias + col_align * 3;
  const float *output_bias = bias + col_align;

  // Output sections are tightly packed at row * col each.
  float *input_gate = gate_buffer;
  float *forget_gate = gate_buffer + row * col * 2;
  float *cell_gate = gate_buffer + row * col * 3;
  float *output_gate = gate_buffer + row * col;

  LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec);
  LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec);
  LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec);
  LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec);
}
|
|
|
|
|
|
|
void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, |
|
|
|
const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, |
|
|
|
float *matmul_buffer[2], const LstmParameter *lstm_param) { |
|
|
|
bool is_vec = lstm_param->batch_ == 1; |
|
|
|
// input * weight |
|
|
|
MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); |
|
|
|
MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->input_size_); |
|
|
|
MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); |
|
|
|
MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->input_size_); |
|
|
|
if (is_vec) { |
|
|
|
UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, |
|
|
|
lstm_param->hidden_size_, lstm_param->col_align_, is_vec); |
|
|
|
} else { |
|
|
|
// pack input for matmul |
|
|
|
PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_); |
|
|
|
UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, |
|
|
|
lstm_param->hidden_size_, lstm_param->col_align_, is_vec); |
|
|
|
} |
|
|
|
|
|
|
|
// state * weight |
|
|
|
MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->hidden_size_); |
|
|
|
MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->hidden_size_); |
|
|
|
MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->hidden_size_); |
|
|
|
MatMulAcc(output_gate, hidden_state, state_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->hidden_size_); |
|
|
|
float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; |
|
|
|
const float *state_bias = bias + lstm_param->col_align_ * 4; |
|
|
|
if (is_vec) { |
|
|
|
UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, |
|
|
|
lstm_param->hidden_size_, lstm_param->col_align_, is_vec); |
|
|
|
} else { |
|
|
|
// pack state for matmul |
|
|
|
PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_); |
|
|
|
UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, |
|
|
|
lstm_param->hidden_size_, lstm_param->col_align_, is_vec); |
|
|
|
} |
|
|
|
ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); |
|
|
|
|
|
|
|
float *input_gate = gate_buffer; |
|
|
|
float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; |
|
|
|
float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; |
|
|
|
float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; |
|
|
|
// update input_gate |
|
|
|
Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate); |
|
|
|
Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); |
|
|
|
|
|
|
|
// update forget_gate |
|
|
|
Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate); |
|
|
|
Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); |
|
|
|
|
|
|
|
// update cell_gate |
|
|
|
Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); |
|
|
|
Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); |
|
|
|
// update cell state |
|
|
|
UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->smooth_); |
|
|
|
UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_, |
|
|
|
lstm_param->hidden_size_, lstm_param->smooth_); |
|
|
|
|
|
|
|
// update output_gate |
|
|
|
Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); |
|
|
|
Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); |
|
|
|
// update output |
|
|
|
UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_parm->batch_, lstm_parm->hidden_size_, |
|
|
|
lstm_parm->smooth_); |
|
|
|
memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); |
|
|
|
|
|
|
|
if (!(lstm_parm->smooth_ >= -FLT_EPSILON && lstm_parm->smooth_ <= FLT_EPSILON)) { |
|
|
|
memcpy(cell_state, state_buffer, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); |
|
|
|
memcpy(hidden_state, state_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_, |
|
|
|
lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); |
|
|
|
UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, |
|
|
|
lstm_param->smooth_); |
|
|
|
memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); |
|
|
|
|
|
|
|
if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) { |
|
|
|
memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); |
|
|
|
memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_, |
|
|
|
lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, |
|
|
|
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, |
|
|
|
const LstmParameter *lstm_parm) { |
|
|
|
float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], |
|
|
|
const LstmParameter *lstm_param) { |
|
|
|
// forward |
|
|
|
const float *input_input_weight = weight_i; |
|
|
|
const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; |
|
|
|
const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; |
|
|
|
const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; |
|
|
|
|
|
|
|
const float *state_input_weight = weight_h; |
|
|
|
const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; |
|
|
|
const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; |
|
|
|
const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; |
|
|
|
|
|
|
|
for (int t = 0; t < lstm_parm->seq_len_; t++) { |
|
|
|
const float *input_ptr = input + t * lstm_parm->input_step_; |
|
|
|
float *output_ptr = output + t * lstm_parm->output_step_; |
|
|
|
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, input_output_weight, |
|
|
|
state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, |
|
|
|
cell_state, gate_buffer, state_buffer, lstm_parm); |
|
|
|
for (int t = 0; t < lstm_param->seq_len_; t++) { |
|
|
|
const float *input_ptr = input + t * lstm_param->input_step_; |
|
|
|
float *output_ptr = output + t * lstm_param->output_step_; |
|
|
|
LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer, |
|
|
|
matmul_buffer, lstm_param); |
|
|
|
} |
|
|
|
|
|
|
|
// backward |
|
|
|
if (lstm_parm->bidirectional_) { |
|
|
|
input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; |
|
|
|
input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; |
|
|
|
input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; |
|
|
|
input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; |
|
|
|
|
|
|
|
state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; |
|
|
|
state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; |
|
|
|
state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; |
|
|
|
state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; |
|
|
|
|
|
|
|
float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; |
|
|
|
const float *backward_bias = bias + 4 * lstm_parm->hidden_size_; |
|
|
|
float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; |
|
|
|
float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; |
|
|
|
for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { |
|
|
|
const float *input_ptr = input + t * lstm_parm->input_step_; |
|
|
|
float *output_ptr = backward_output + t * lstm_parm->output_step_; |
|
|
|
LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, |
|
|
|
input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, |
|
|
|
backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, lstm_parm); |
|
|
|
if (lstm_param->bidirectional_) { |
|
|
|
const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; |
|
|
|
const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; |
|
|
|
const float *backward_bias = bias + 8 * lstm_param->hidden_size_; |
|
|
|
float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; |
|
|
|
float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; |
|
|
|
float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; |
|
|
|
for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { |
|
|
|
const float *input_ptr = input + t * lstm_param->input_step_; |
|
|
|
float *output_ptr = backward_output + t * lstm_param->output_step_; |
|
|
|
LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state, |
|
|
|
backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); |
|
|
|
} |
|
|
|
} |
|
|
|
} |