/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "nnacl/fp32/lstm_fp32.h" #include #include #include "nnacl/fp32/activation_fp32.h" #include "nnacl/fp32/arithmetic_fp32.h" #include "nnacl/fp32/matmul_fp32.h" // input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { for (int r = 0; r < rows; r++) { for (int c = 0; c < cols; c++) { float res = 0; const float *input_col = input + r * inner_size; const float *weight_col = weight + c * inner_size; int index = 0; #ifdef ENABLE_ARM float32x4_t out = vdupq_n_f32(0.0f); for (; index <= inner_size - 4; index += 4) { float32x4_t in_0 = vld1q_f32(input_col + index); float32x4_t in_1 = vld1q_f32(weight_col + index); out = vmlaq_f32(out, in_1, in_0); } #ifdef ENABLE_ARM64 res += vaddvq_f32(out); #else float32x2_t add2 = vadd_f32(vget_low_f32(out), vget_high_f32(out)); float32x2_t add4 = vpadd_f32(add2, add2); res += vget_lane_f32(add4, 0); #endif #endif for (; index < inner_size; index++) { res += input_col[index] * weight_col[index]; } output[r * cols + c] += res; } } } void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) { int index = 0; #ifdef ENABLE_ARM for (; index <= element_size - 4; index += 4) { float32x4_t in_0 = vld1q_f32(input0 + index); float32x4_t in_1 = vld1q_f32(input1 + index); float32x4_t out = vld1q_f32(output + index); out = vmlaq_f32(out, in_1, in_0); vst1q_f32(output + index, out); } #endif for (; index < element_size; index++) { output[index] += input0[index] * input1[index]; } } int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { int index = 0; #ifdef ENABLE_NEON for (; index <= element_size - 4; index += C4NUM) { float32x4_t vin0 = vld1q_f32(input0 + index); float32x4_t vout = vld1q_f32(output + index); vout = vmlaq_n_f32(vout, vin0, input1); vst1q_f32(output + index, vout); } #endif for (; index < element_size; index++) { output[index] += input0[index] * input1; } return NNACL_OK; } void UpdataState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, float *state_buffer, int batch, int hidden_size, const float smooth) { if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // smooth * old_cell_state memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); ArithmeticParameter parameter; parameter.in_elements_num0_ = batch * hidden_size; parameter.in_elements_num1_ = 1; ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, ¶meter); } ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { // (1 - smooth) * new_cell_state ElementOptMulAcc(cell_state, 1 - smooth, state_buffer, batch * hidden_size); } } void UpdataOutput(const float *cell_state, const float *output_gate, float *hidden_state, float *state_buffer_in, int batch, int hidden_size, const float smooth) { float *state_buffer = state_buffer_in + batch * hidden_size; if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float)); ArithmeticParameter parameter; parameter.in_elements_num0_ = batch * hidden_size; parameter.in_elements_num1_ = 1; ElementOptMul(state_buffer, &smooth, state_buffer, batch * hidden_size, ¶meter); } Tanh(cell_state, batch * hidden_size, hidden_state); ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size); if (!(smooth >= -FLT_EPSILON && smooth <= FLT_EPSILON)) { ElementOptMulAcc(hidden_state, 1 - smooth, state_buffer, batch * hidden_size); } } void LstmMatmul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { if (is_vec) { memcpy(c, bias, col * sizeof(float)); MatMulAcc(c, a, b, row, col, deep); } else { MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); } } void PackLstmInput(float *dst, const float *src, int row, int deep) { #ifdef ENABLE_AVX RowMajor2Col6Major(src, dst, row, deep); #elif defined(ENABLE_SSE) RowMajor2Col4Major(src, dst, row, deep); #else RowMajor2Col12Major(src, dst, row, deep); #endif } void UpdateGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, int col, int col_align, bool is_vec) { const float *input_weight = weight; const float *forget_weight = weight + deep * col * 2; const float *cell_weight = weight + deep * col * 3; const float *output_weight = weight + deep * col; const float *input_bias = bias; const float *forget_bias = bias + col_align * 2; const float *cell_bias = bias + col_align * 3; const float *output_bias = bias + col_align; float *input_gate = gate_buffer; float *forget_gate = gate_buffer + row * col * 2; float *cell_gate = gate_buffer + row * col * 3; float *output_gate = gate_buffer + row * col; LstmMatmul(input_gate, input, input_weight, input_bias, row, deep, col, is_vec); LstmMatmul(forget_gate, input, forget_weight, forget_bias, row, deep, col, is_vec); LstmMatmul(cell_gate, input, cell_weight, cell_bias, row, deep, col, is_vec); LstmMatmul(output_gate, input, output_weight, output_bias, row, deep, col, is_vec); } void LstmStepUnit(float *output, const float *input, const float *input_weight, const float *state_weight, const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], const LstmParameter *lstm_param) { bool is_vec = lstm_param->batch_ == 1; // input * weight if (is_vec) { UpdateGate(gate_buffer, input, input_weight, bias, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } else { // pack input for matmul PackLstmInput(matmul_buffer[0], input, lstm_param->batch_, lstm_param->input_size_); UpdateGate(gate_buffer, matmul_buffer[0], input_weight, bias, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } // state * weight float *state_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 4; const float *state_bias = bias + lstm_param->col_align_ * 4; if (is_vec) { UpdateGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } else { // pack state for matmul PackLstmInput(matmul_buffer[1], hidden_state, lstm_param->batch_, lstm_param->hidden_size_); UpdateGate(state_gate, matmul_buffer[1], state_weight, state_bias, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->col_align_, is_vec); } ElementAdd(gate_buffer, state_gate, gate_buffer, 4 * lstm_param->batch_ * lstm_param->hidden_size_); float *input_gate = gate_buffer; float *forget_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 2; float *cell_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_ * 3; float *output_gate = gate_buffer + lstm_param->batch_ * lstm_param->hidden_size_; // update input_gate Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); // update forget_gate Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); // update cell_gate Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); // update cell state UpdataState(cell_state, forget_gate, input_gate, cell_gate, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->smooth_); // update output_gate Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); // update output UpdataOutput(cell_state, output_gate, hidden_state, state_buffer, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->smooth_); memcpy(output, hidden_state, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); if (!(lstm_param->smooth_ >= -FLT_EPSILON && lstm_param->smooth_ <= FLT_EPSILON)) { memcpy(cell_state, state_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); memcpy(hidden_state, state_buffer + lstm_param->batch_ * lstm_param->hidden_size_, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); } } void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, float *state_buffer, float *matmul_buffer[2], const LstmParameter *lstm_param) { // forward for (int t = 0; t < lstm_param->seq_len_; t++) { const float *input_ptr = input + t * lstm_param->input_step_; float *output_ptr = output + t * lstm_param->output_step_; LstmStepUnit(output_ptr, input_ptr, weight_i, weight_h, bias, hidden_state, cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); } // backward if (lstm_param->bidirectional_) { const float *backward_weight_i = weight_i + 4 * lstm_param->col_align_ * lstm_param->input_size_; const float *backward_weight_h = weight_h + 4 * lstm_param->col_align_ * lstm_param->hidden_size_; const float *backward_bias = bias + 8 * lstm_param->hidden_size_; float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; for (int t = lstm_param->seq_len_ - 1; t >= 0; t--) { const float *input_ptr = input + t * lstm_param->input_step_; float *output_ptr = backward_output + t * lstm_param->output_step_; LstmStepUnit(output_ptr, input_ptr, backward_weight_i, backward_weight_h, backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, state_buffer, matmul_buffer, lstm_param); } } }