diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc new file mode 100644 index 0000000000..22b52ad8a7 --- /dev/null +++ b/mindspore/lite/src/ops/lstm.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ops.h" +#include "include/errorcode.h" +#include "utils/log_adapter.h" +#include "src/ir/tensor.h" + +namespace mindspore::lite { +const int kLstmInputNum = 6; +const int kLstmOutputNum = 3; +int Lstm::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive != nullptr); + if (inputs_.size() != kLstmInputNum || outputs_.size() != kLstmOutputNum) { + MS_LOG(ERROR) << "OpLstm inputs or outputs size error."; + return RET_INPUT_TENSOR_ERROR; + } + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto weight_i = inputs_.front(); + MS_ASSERT(input0 != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + + std::vector in_shape = input->shape(); + std::vector w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size + if (in_shape.size() != 3 || w_shape.size() != 3) { + MS_LOG(ERROR) << "OpLstm input dims should be 3."; + return RET_ERROR; + } + + auto lstm_prim = this->primitive->value_as_Lstm(); + int hidden_size = w_shape[1] / 4; + + // set output + std::vector out_shape(in_shape); + out_shape[2] = hidden_size; + if (lstm_prim->bidirection()) { + 
out_shape.insert(out_shape.begin() + 1, 2); + } + output->set_shape(out_shape); + + // set hidden state, cell state + std::vector state_shape(in_shape); + state_shape[0] = lstm_prim->bidirection() ? 2 : 1; + state_shape[2] = hidden_size; + outputs_[1]->set_shape(state_shape); + outputs_[2]->set_shape(state_shape); + + for (int i = 0; i < kLstmOutputNum; i++) { + outputs_[i]->set_data_type(input->data_type()); + outputs_[i]->SetFormat(input->GetFormat()); + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index e1dd289804..28be317f21 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -770,6 +770,13 @@ class QuantDTypeCast : public Primitive { const schema::QuantDTypeCast *GetAttribute() const { return this->primitive->value_as_QuantDTypeCast(); } int InferShape(std::vector inputs, std::vector outputs) override; }; + +class Lstm : public Primitive { + public: + explicit Lstm(schema::Primitive *primitive) : Primitive(primitive) {} + const schema::Lstm *GetAttribute() const { return this->primitive->value_as_Lstm(); } + int InferShape(std::vector inputs, std::vector outputs) override; +}; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_SRC_OPS_OPS_H_ diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 162a9d3275..e3f6f22efd 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -67,6 +67,7 @@ #include "src/runtime/kernel/arm/opclib/fp32/space_to_depth.h" #include "src/runtime/kernel/arm/opclib/fp32/space_to_batch.h" #include "src/runtime/kernel/arm/opclib/int8/quant_dtype_cast.h" +#include "src/runtime/kernel/arm/opclib/fp32/lstm.h" namespace mindspore::kernel { OpParameter *PopulateFillParameter(const lite::Primitive *primitive) { @@ -1169,6 +1170,23 @@ OpParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) { return 
reinterpret_cast(prior_box_param); } +OpParameter *PopulateLstmParameter(const lite::Primitive *primitive) { + LstmParameter *lstm_param = new (std::nothrow) LstmParameter(); + if (lstm_param == nullptr) { + MS_LOG(ERROR) << "new LstmParameter fail!"; + return nullptr; + } + lstm_param->op_parameter_.type_ = primitive->Type(); + auto param = primitive->Value()->value_as_Lstm(); + if (param == nullptr) { + delete (lstm_param); + MS_LOG(ERROR) << "get Lstm param nullptr."; + return nullptr; + } + lstm_param->bidirectional_ = param->bidirection(); + return reinterpret_cast(lstm_param); +} + PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; @@ -1244,6 +1262,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter; populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; populate_parameter_funcs_[schema::PrimitiveType_QuantDTypeCast] = PopulateQuantDTypeCastParameter; + populate_parameter_funcs_[schema::PrimitiveType_Lstm] = PopulateLstmParameter; } PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc new file mode 100644 index 0000000000..cc779de6e5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/lstm.h" +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Lstm; + +namespace mindspore::kernel { +int LstmCPUKernel::InitParam() { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + std::vector in_shape = input->shape(); + lstm_parm_->seq_len_ = in_shape[0]; + lstm_parm_->batch_ = in_shape[1]; + lstm_parm_->input_size_ = in_shape[2]; + + auto weight_i = inputs_[1]; + MS_ASSERT(weight_i != nullptr); + std::vector w_shape = weight_i->shape(); + lstm_parm_->hidden_size_ = w_shape[1] / 4; + + lstm_parm_->input_step_ = lstm_parm_->batch_ * lstm_parm_->input_size_; + lstm_parm_->output_step_ = lstm_parm_->bidirectional_ ? 
2 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ + : lstm_parm_->batch_ * lstm_parm_->hidden_size_; + return RET_OK; +} + +int LstmCPUKernel::InitBuffer() { + gate_buffer_ = reinterpret_cast(malloc(4 * lstm_parm_->batch_ * lstm_parm_->hidden_size_ * sizeof(float))); + if (gate_buffer_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc gate_buffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::InitWeightBias() { + // copy weight_i and weight_h + auto weight_i = inputs_.at(1); + MS_ASSERT(weight_i != nullptr); + weight_i_ptr_ = reinterpret_cast(malloc(weight_i->ElementsNum() * sizeof(float))); + if (weight_i_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; + return RET_ERROR; + } + memcpy(weight_i_ptr_, weight_i->Data(), weight_i->ElementsNum() * sizeof(float)); + + auto weight_h = inputs_.at(2); + MS_ASSERT(weight_h != nullptr); + weight_h_ptr_ = reinterpret_cast(malloc(weight_h->ElementsNum() * sizeof(float))); + if (weight_h_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ error."; + return RET_ERROR; + } + memcpy(weight_h_ptr_, weight_h->Data(), weight_h->ElementsNum() * sizeof(float)); + + // init bias + int bias_num = lstm_parm_->bidirectional_ ? 
2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_; + bias_ptr_ = reinterpret_cast(malloc(bias_num * sizeof(float))); + if (bias_ptr_ == nullptr) { + MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error."; + return RET_ERROR; + } + + auto bias_data = reinterpret_cast(inputs_.at(3)->Data()); + int state_bias_offset = 4 * lstm_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + if (lstm_parm_->bidirectional_) { + bias_data += 4 * lstm_parm_->hidden_size_ * 2; + auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_; + for (int i = 0; i < state_bias_offset; i++) { + backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset]; + } + } + return RET_OK; +} + +int LstmCPUKernel::Init() { + auto ret = InitParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; + return RET_ERROR; + } + + ret = InitWeightBias(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::ReSize() { + free(gate_buffer_); + + auto ret = InitParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; + return RET_ERROR; + } + + ret = InitBuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error."; + return RET_ERROR; + } + return RET_OK; +} + +int LstmCPUKernel::Run() { + auto input = inputs_.at(kInputIndex); + MS_ASSERT(input != nullptr); + auto hidden_state = inputs_.at(4); + MS_ASSERT(hidden_state != nullptr); + auto cell_state = inputs_.at(5); + MS_ASSERT(cell_state != nullptr); + auto output = outputs_.at(0); + MS_ASSERT(output != nullptr); + + auto input_ptr = reinterpret_cast(input->Data()); + auto output_ptr = reinterpret_cast(output->Data()); + + auto output_hidden_state = 
outputs_[1]; + memcpy(output_hidden_state->Data(), hidden_state->Data(), hidden_state->ElementsNum() * sizeof(float)); + auto output_cell_state = outputs_[2]; + memcpy(output_cell_state->Data(), cell_state->Data(), cell_state->ElementsNum() * sizeof(float)); + + Lstm(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, bias_ptr_, + reinterpret_cast(output_hidden_state->Data()), reinterpret_cast(output_cell_state->Data()), + gate_buffer_, lstm_parm_); + return RET_OK; +} + +kernel::LiteKernel *CpuLstmKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::Context *ctx, const kernel::KernelKey &desc) { + MS_ASSERT(opParameter != nullptr); + MS_ASSERT(desc.type == schema::PrimitiveType_Lstm); + + auto *kernel = new (std::nothrow) LstmCPUKernel(opParameter, inputs, outputs, ctx); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + delete kernel; + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Lstm, CpuLstmKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h new file mode 100644 index 0000000000..35ec4a115d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/lstm.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/fp32/lstm.h" + +namespace mindspore::kernel { +class LstmCPUKernel : public LiteKernel { + public: + LstmCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::Context *ctx) + : LiteKernel(parameter, inputs, outputs) { + lstm_parm_ = reinterpret_cast(opParameter); + } + + ~LstmCPUKernel() override { + free(gate_buffer_); + free(weight_i_ptr_); + free(weight_h_ptr_); + free(bias_ptr_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + int InitParam(); + int InitBuffer(); + int InitWeightBias(); + + private: + float *gate_buffer_; + float *weight_i_ptr_; + float *weight_h_ptr_; + float *bias_ptr_; + LstmParameter *lstm_parm_; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSTM_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.cc new file mode 100644 index 0000000000..3504000572 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/fp32/lstm.h" +#include +#include "src/runtime/kernel/arm/opclib/fp32/activation.h" +#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" + +void InitGate(float *gate_buffer, const float *bias, LstmParameter *lstm_parm) { + int gate_offest = 0; + for (int l = 0; l < 4; l++) { + int batch_offest = gate_offest; + int bias_offest = l * lstm_parm->hidden_size_; + for (int b = 0; b < lstm_parm->batch_; b++) { + memcpy(gate_buffer + batch_offest, bias + bias_offest, lstm_parm->hidden_size_ * sizeof(float)); + batch_offest += lstm_parm->hidden_size_; + } + gate_offest += lstm_parm->batch_ * lstm_parm->hidden_size_; + } +} + +// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] +void MatMulAcc(float *output, const float *input, const float *weight, int rows, int cols, int inner_size) { + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + float res = 0; + for (int i = 0; i < inner_size; i++) { + res += input[r * inner_size + i] * weight[c * inner_size + i]; + } + output[r * cols + c] += res; + } + } +} + +void ElementMulAcc(float *input0, float *input1, float *output, int element_size) { + for (int index = 0; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +void UpdataState(float *cell_state, float *forget_gate, float *input_gate, float *cell_gate, int batch, + int hidden_size) { + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); 
+} + +void UpdataOutput(float *cell_state, float *output_gate, float *hidden_state, int batch, int hidden_size) { + Tanh(cell_state, batch * hidden_size, hidden_state); + ElementMul(hidden_state, output_gate, hidden_state, batch * hidden_size); +} + +void LstmStepUnit(float *output, const float *input, const float *input_input_weight, const float *input_forget_weight, + const float *input_cell_weight, const float *input_output_weight, const float *state_input_weight, + const float *state_forget_weight, const float *state_cell_weight, const float *state_output_weight, + const float *bias, float *hidden_state, float *cell_state, float *gate_buffer, + LstmParameter *lstm_parm) { + InitGate(gate_buffer, bias, lstm_parm); + + float *input_gate = gate_buffer; + float *forget_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 2; + float *cell_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 3; + float *output_gate = gate_buffer + lstm_parm->batch_ * lstm_parm->hidden_size_ * 1; + + // input * weight + MatMulAcc(input_gate, input, input_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); + MatMulAcc(forget_gate, input, input_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->input_size_); + MatMulAcc(cell_gate, input, input_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, lstm_parm->input_size_); + MatMulAcc(output_gate, input, input_output_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->input_size_); + + // state * weight + MatMulAcc(input_gate, hidden_state, state_input_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(forget_gate, hidden_state, state_forget_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(cell_gate, hidden_state, state_cell_weight, lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + MatMulAcc(output_gate, hidden_state, state_output_weight, 
lstm_parm->batch_, lstm_parm->hidden_size_, + lstm_parm->hidden_size_); + + // update input_gate + Sigmoid(input_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, input_gate); + + // update forget_gate + Sigmoid(forget_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, forget_gate); + + // update cell_gate + Tanh(cell_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, cell_gate); + // update cell state + UpdataState(cell_state, forget_gate, input_gate, cell_gate, lstm_parm->batch_, lstm_parm->hidden_size_); + + // update output_gate + Sigmoid(output_gate, lstm_parm->batch_ * lstm_parm->hidden_size_, output_gate); + // update output + UpdataOutput(cell_state, output_gate, hidden_state, lstm_parm->batch_, lstm_parm->hidden_size_); + memcpy(output, hidden_state, lstm_parm->batch_ * lstm_parm->hidden_size_ * sizeof(float)); +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias, + float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm) { + // forward + const float *input_input_weight = weight_i; + const float *input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 2; + const float *input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 3; + const float *input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 1; + + const float *state_input_weight = weight_h; + const float *state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 2; + const float *state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 3; + const float *state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 1; + + for (int t = 0; t < lstm_parm->seq_len_; t++) { + const float *input_ptr = input + t * lstm_parm->input_step_; + float *output_ptr = output + t * lstm_parm->output_step_; + LstmStepUnit(output_ptr, input_ptr, input_input_weight, 
input_forget_weight, input_cell_weight, input_output_weight, + state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, bias, hidden_state, + cell_state, gate_buffer, lstm_parm); + } + + // backward + if (lstm_parm->bidirectional_) { + input_input_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 4; + input_forget_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 6; + input_cell_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 7; + input_output_weight = weight_i + lstm_parm->input_size_ * lstm_parm->hidden_size_ * 5; + + state_input_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 4; + state_forget_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 6; + state_cell_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 7; + state_output_weight = weight_h + lstm_parm->hidden_size_ * lstm_parm->hidden_size_ * 5; + + float *backward_output = output + lstm_parm->batch_ * lstm_parm->hidden_size_; + const float *backward_bias = bias + 4 * lstm_parm->hidden_size_; + float *backward_cell_state = cell_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_parm->batch_ * lstm_parm->hidden_size_; + for (int t = lstm_parm->seq_len_ - 1; t >= 0; t--) { + const float *input_ptr = input + t * lstm_parm->input_step_; + float *output_ptr = backward_output + t * lstm_parm->output_step_; + LstmStepUnit(output_ptr, input_ptr, input_input_weight, input_forget_weight, input_cell_weight, + input_output_weight, state_input_weight, state_forget_weight, state_cell_weight, state_output_weight, + backward_bias, backward_hidden_state, backward_cell_state, gate_buffer, lstm_parm); + } + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.h new file mode 100644 index 0000000000..6775ac6314 --- /dev/null +++ 
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_

#include "src/runtime/kernel/arm/opclib/op_base.h"

// Parameters of the fp32 LSTM kernel.
// op_parameter_ must stay the first member: the struct is reinterpret_cast to
// and from OpParameter by PopulateLstmParameter and LstmCPUKernel.
struct LstmParameter {
  OpParameter op_parameter_;
  int input_size_;   // last dim of the input tensor
  int hidden_size_;  // output_size
  int seq_len_;      // first dim of the input tensor
  int batch_;        // second dim of the input tensor
  int input_step_;   // elements between consecutive time steps of the input
  int output_step_;  // elements between consecutive time steps of the output
  bool bidirectional_;  // when true, Lstm() also runs a reversed pass
};

// Runs the LSTM over seq_len_ time steps.
// hidden_state and cell_state are read as the initial state and updated in
// place; gate_buffer is scratch space used for the four gates of one step
// (LstmCPUKernel allocates 4 * batch_ * hidden_size_ floats for it).
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *bias,
          float *hidden_state, float *cell_state, float *gate_buffer, LstmParameter *lstm_parm);

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_LSTM_H_
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/lstm.h" +#include "mindspore/lite/src/kernel_registry.h" +#include "mindspore/lite/src/ops/ops.h" + +namespace mindspore { +class LstmFp32 : public mindspore::Common { + public: + LstmFp32() {} +}; + +void InitLstmParam(LstmParameter *lstm_param) { + lstm_param->seq_len_ = 4; + lstm_param->batch_ = 1; + lstm_param->input_size_ = 2; + lstm_param->hidden_size_ = 3; + lstm_param->bidirectional_ = false; +} + +void InitLstmForwardCreator(std::vector *inputs, std::vector *outputs, + const LstmParameter *lstm_param) { + // prepare input + std::vector input_data = {1.3889, -0.3006, -0.1787, 2.1504, -0.3181, 0.4945, -0.4758, -0.8187}; + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->input_size_}); + input->MallocData(); + memcpy(input->Data(), input_data.data(), input_data.size() * sizeof(float)); + + // prepare weight_i + std::vector weight_i_data = {0.21368974, -0.3778776, 0.05025542, 0.09011161, 0.18355745, 0.5491228, + -0.14186832, -0.4655916, 0.49541366, -0.44039622, 0.5625571, 0.23325664, + 0.3449825, -0.42750397, 0.01911497, -0.4125802, -0.56690466, 0.50593233, + -0.29129684, -0.27841482, 0.01964372, -0.42543447, 0.41720617, -0.30054367}; + auto *weight_i = new lite::tensor::Tensor; + weight_i->set_data_type(kNumberTypeFloat32); + weight_i->SetFormat(schema::Format_NHWC); + 
weight_i->set_shape({1, lstm_param->hidden_size_ * 4, lstm_param->input_size_}); + weight_i->MallocData(); + memcpy(weight_i->Data(), weight_i_data.data(), weight_i_data.size() * sizeof(float)); + + // prepare weight_r + std::vector weight_h_data = { + -0.03424168, 0.00643545, 0.36867607, -0.08598137, 0.19804275, -0.11319417, -0.0244593, -0.16440144, -0.07268238, + 0.09828371, 0.33358777, 0.53381383, -0.39431244, -0.06005383, -0.3520246, 0.42687547, 0.5772828, 0.5380008, + -0.16130409, -0.24737108, 0.42409766, -0.50648475, 0.48223662, -0.5221103, -0.49216837, -0.29084128, 0.3408438, + 0.34080023, 0.49467337, 0.23473483, 0.01759732, 0.04691631, 0.45574808, -0.29481018, 0.29442167, -0.36718}; + auto *weight_h = new lite::tensor::Tensor; + weight_h->set_data_type(kNumberTypeFloat32); + weight_h->SetFormat(schema::Format_NHWC); + weight_h->set_shape({1, lstm_param->hidden_size_ * 4, lstm_param->hidden_size_}); + weight_h->MallocData(); + memcpy(weight_h->Data(), weight_h_data.data(), weight_h_data.size() * sizeof(float)); + + // prepare bias + std::vector bias_data = {-0.00207639, 0.16391152, -0.00069344, -0.32945693, -0.367423, 0.28301108, + -0.17930457, 0.5278388, 0.12598747, -0.53130764, 0.1479364, 0.16695255, + -0.00708795, -0.46417096, -0.23966661, -0.17496741, -0.19166365, -0.50466555, + -0.23593256, -0.3911457, 0.51128435, 0.5128727, 0.253451, -0.51891875}; + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->SetFormat(schema::Format_NHWC); + bias->set_shape({1, lstm_param->hidden_size_ * 4 * 2}); + bias->MallocData(); + memcpy(bias->Data(), bias_data.data(), bias_data.size() * sizeof(float)); + + // prepare state + std::vector state_data = {0, 0, 0}; + auto *state = new lite::tensor::Tensor; + state->set_data_type(kNumberTypeFloat32); + state->SetFormat(schema::Format_NHWC); + state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + state->MallocData(); + memcpy(state->Data(), state_data.data(), 
state_data.size() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight_i); + inputs->push_back(weight_h); + inputs->push_back(bias); + inputs->push_back(state); + inputs->push_back(state); + + // malloc output buffer, for arm cpu, format: N C4 H W 4 + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->hidden_size_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + + auto *cell_state = new lite::tensor::Tensor; + cell_state->set_data_type(kNumberTypeFloat32); + cell_state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + cell_state->SetFormat(schema::Format_NHWC); + cell_state->MallocData(); + memset(cell_state->Data(), 0, cell_state->ElementsNum() * sizeof(float)); + + auto *hidden_state = new lite::tensor::Tensor; + hidden_state->set_data_type(kNumberTypeFloat32); + hidden_state->set_shape({1, lstm_param->batch_, lstm_param->hidden_size_}); + hidden_state->SetFormat(schema::Format_NHWC); + hidden_state->MallocData(); + memset(hidden_state->Data(), 0, hidden_state->ElementsNum() * sizeof(float)); + + outputs->push_back(output); + outputs->push_back(cell_state); + outputs->push_back(hidden_state); +} + +void CompareOutput(lite::tensor::Tensor *output, std::vector data) { + for (int i = 0; i < output->ElementsNum(); i++) { + std::cout << reinterpret_cast(output->Data())[i] << ", "; + } + std::cout << std::endl; + + Common::CompareOutputData(reinterpret_cast(output->Data()), data.data(), output->ElementsNum(), 0.0001); +} + +TEST_F(LstmFp32, LstmForwardFp32Accuracy) { + // prepare stage + auto lstm_param = new LstmParameter(); + InitLstmParam(lstm_param); + + // init ctx + auto ctx = new lite::Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitLstmForwardCreator(&inputs, &outputs, 
lstm_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + std::vector output0_data = {-0.0702, 0.1225, 0.0876, -0.0357, -0.0227, -0.2294, + -0.0345, -0.0108, -0.2002, 0.0451, 0.0853, -0.1205}; + CompareOutput(outputs[0], output0_data); + + std::vector output1_data = {0.0451, 0.0853, -0.1205}; + CompareOutput(outputs[1], output1_data); + + std::vector output2_data = {0.0989, 0.2094, -0.4132}; + CompareOutput(outputs[2], output2_data); + + delete lstm_param; + for (int i = 0; i < inputs.size() - 1; i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "LstmFp32 forward accuracy passed"; +} + +void InitLstmBackwardCreator(std::vector *inputs, std::vector *outputs, + const LstmParameter *lstm_param) { + // prepare input + std::vector input_data = {1.4305, 0.5342, -0.9221, 0.0527, 2.3770, -0.3697, -0.2833, -2.1285}; + auto *input = new lite::tensor::Tensor; + input->set_data_type(kNumberTypeFloat32); + input->set_shape({lstm_param->seq_len_, lstm_param->batch_, lstm_param->input_size_}); + input->MallocData(); + memcpy(input->Data(), input_data.data(), input_data.size() * sizeof(float)); + + // prepare weight_i + std::vector weight_i_data = { + -0.19253477, -0.007966279, -0.06039094, 0.27697134, -0.5071223, 0.18996351, 0.20472168, -0.1007814, + 0.04282999, 0.20836472, -0.4654655, 0.050321221, -0.3431457, 0.22256428, 0.29294532, 0.45042896, + 0.20468240, 0.13078391, -0.20987969, -0.3173505, -0.3813517, 0.10205835, 0.21858131, -0.0386473, + 0.5512280, -0.2763766, 
-0.3593936, -0.5181975, 0.3469863, -0.38533931, 0.010202527, -0.46598294, + -0.5740513, 0.06127524, -0.03960543, 0.2478809, -0.17296993, 0.19159525, -0.4976995, 0.05985528, + 0.3653409, 0.386924, 0.3170289, -0.08830952, -0.31105759, 0.3110240, 0.15174299, 0.287579894}; + auto *weight_i = new lite::tensor::Tensor; + weight_i->set_data_type(kNumberTypeFloat32); + weight_i->SetFormat(schema::Format_NHWC); + weight_i->set_shape({2, lstm_param->hidden_size_ * 4, lstm_param->input_size_}); + weight_i->MallocData(); + memcpy(weight_i->Data(), weight_i_data.data(), weight_i_data.size() * sizeof(float)); + + // prepare weight_r + std::vector weight_h_data = { + 0.106934666, -0.50430017, 0.33296257, -0.288117021, -0.38019785, -0.147071093, 0.422707557, 0.41497004, + -0.5329730, -0.430150926, -0.032713949, 0.35401260, 0.179495036, -0.14158579, 0.380428612, -0.175597071, + 0.54088723, -0.403292059, -0.287720531, -0.51250511, -0.15405902, -0.440592586, 0.16726928, -0.0163397789, + 0.51673841, 0.5094323, -0.137105107, -0.181070089, -0.47221425, -0.38046866, -0.206725060, 0.248537719, + -0.23961094, -0.117781728, 0.426800847, 0.0266208052, -0.197408229, 0.54831492, -0.280048757, -0.125062286, + -0.29929456, 0.42354834, -0.401066303, 0.356340110, 0.54629492, -0.15852552, 0.131406366, -0.101815432, + 0.0121276974, -0.53553336, 0.121099889, 0.060554087, 0.46259057, -0.49666053, 0.090806663, 0.20542401, + -0.38674920, -0.23874849, -0.5222138, 0.57537007, 0.113343358, -0.35233467, -0.25532332, 0.159506142, + 0.35996592, -0.201961308, -0.16323345, 0.119177639, -0.12677872, -0.175229549, -0.160024613, -0.21058899}; + auto *weight_h = new lite::tensor::Tensor; + weight_h->set_data_type(kNumberTypeFloat32); + weight_h->SetFormat(schema::Format_NHWC); + weight_h->set_shape({2, lstm_param->hidden_size_ * 4, lstm_param->hidden_size_}); + weight_h->MallocData(); + memcpy(weight_h->Data(), weight_h_data.data(), weight_h_data.size() * sizeof(float)); + + // prepare bias + std::vector bias_data 
= { + 0.57061123, -0.25357073, -0.146834075, 0.412972748, -0.27809411, -0.0542128682, -0.45384609, -0.53261917, + 0.222133636, -0.18093895, -0.045559883, 0.09109061, 0.080319643, 0.455167174, 0.36235427, -0.00164419412, + -0.135566502, 0.41905909, -0.450117409, 0.50565385, -0.077815443, -0.47051778, -0.141349375, -0.338519752, + 0.48683023, 0.282384872, 0.13399660, -0.382526844, -0.23370727, -0.184681564, 0.45679104, -0.339453905, + 0.452010273, 0.0552094578, 0.328843057, 0.127738714, -0.127084732, -0.334061294, -0.46742400, -0.401568055, + 0.23712641, -0.052937567, 0.272351622, 0.42767739, 0.303884744, -0.46025499, -0.43985402, 0.256422877}; + auto *bias = new lite::tensor::Tensor; + bias->set_data_type(kNumberTypeFloat32); + bias->SetFormat(schema::Format_NHWC); + bias->set_shape({2, lstm_param->hidden_size_ * 4 * 2}); + bias->MallocData(); + memcpy(bias->Data(), bias_data.data(), bias_data.size() * sizeof(float)); + + // prepare state + std::vector state_data = {0, 0, 0, 0, 0, 0}; + auto *state = new lite::tensor::Tensor; + state->set_data_type(kNumberTypeFloat32); + state->SetFormat(schema::Format_NHWC); + state->set_shape({2, lstm_param->batch_, lstm_param->hidden_size_}); + state->MallocData(); + memcpy(state->Data(), state_data.data(), state_data.size() * sizeof(float)); + + inputs->push_back(input); + inputs->push_back(weight_i); + inputs->push_back(weight_h); + inputs->push_back(bias); + inputs->push_back(state); + inputs->push_back(state); + + // malloc output buffer, for arm cpu, format: N C4 H W 4 + auto *output = new lite::tensor::Tensor; + output->set_data_type(kNumberTypeFloat32); + output->set_shape({lstm_param->seq_len_, 2, lstm_param->batch_, lstm_param->hidden_size_}); + output->SetFormat(schema::Format_NHWC); + output->MallocData(); + memset(output->Data(), 0, output->ElementsNum() * sizeof(float)); + + auto *cell_state = new lite::tensor::Tensor; + cell_state->set_data_type(kNumberTypeFloat32); + cell_state->set_shape({2, lstm_param->batch_, 
lstm_param->hidden_size_}); + cell_state->SetFormat(schema::Format_NHWC); + cell_state->MallocData(); + memset(cell_state->Data(), 0, cell_state->ElementsNum() * sizeof(float)); + + auto *hidden_state = new lite::tensor::Tensor; + hidden_state->set_data_type(kNumberTypeFloat32); + hidden_state->set_shape({2, lstm_param->batch_, lstm_param->hidden_size_}); + hidden_state->SetFormat(schema::Format_NHWC); + hidden_state->MallocData(); + memset(hidden_state->Data(), 0, hidden_state->ElementsNum() * sizeof(float)); + + outputs->push_back(output); + outputs->push_back(cell_state); + outputs->push_back(hidden_state); +} + +TEST_F(LstmFp32, LstmBackwardFp32Accuracy) { + // prepare stage + auto lstm_param = new LstmParameter(); + InitLstmParam(lstm_param); + lstm_param->bidirectional_ = true; + + // init ctx + auto ctx = new lite::Context(); + ctx->thread_num_ = 1; + + // init tensor + std::vector inputs; + std::vector outputs; + InitLstmBackwardCreator(&inputs, &outputs, lstm_param); + + // register op + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Lstm}; + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + kernel::LiteKernel *kernel = creator(inputs, outputs, reinterpret_cast(lstm_param), ctx, desc); + ASSERT_NE(kernel, nullptr); + // op run + kernel->Run(); + + std::cout << "==================output data=================" << std::endl; + std::vector output0_data = {-0.2922, -0.1416, 0.0077, -0.0422, -0.0585, 0.2061, -0.2385, -0.0146, + -0.1796, -0.0554, -0.0973, 0.1013, -0.3062, -0.1516, -0.0310, 0.0459, + -0.0784, 0.0949, 0.0249, -0.0653, -0.0869, -0.1113, -0.2155, -0.0500}; + CompareOutput(outputs[0], output0_data); + + std::vector output1_data = {0.0249, -0.0653, -0.0869, -0.0422, -0.0585, 0.2061}; + CompareOutput(outputs[1], output1_data); + + std::vector output2_data = {0.0373, -0.2322, -0.1477, -0.1621, -0.1808, 0.5146}; + CompareOutput(outputs[2], 
output2_data); + + delete lstm_param; + for (int i = 0; i < inputs.size() - 1; i++) { + delete inputs[i]; + } + for (int i = 0; i < outputs.size(); i++) { + delete outputs[i]; + } + delete kernel; + MS_LOG(INFO) << "LstmFp32 backward accuracy passed"; +} + +} // namespace mindspore