@@ -0,0 +1,139 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/gru_fp16.h"
#include <string.h>
#include "nnacl/fp16/lstm_fp16.h"
#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/fp16/arithmetic_fp16.h"
void InitGruGateFp16(float16_t *gate_buffer, const float16_t *bias, const GruParameter *gru_parm) {
  int gate_offset = 0;
  for (int l = 0; l < 3; l++) {
    int batch_offset = gate_offset;
    int bias_offset = l * gru_parm->hidden_size_;
    for (int b = 0; b < gru_parm->batch_; b++) {
      memcpy(gate_buffer + batch_offset, bias + bias_offset, gru_parm->hidden_size_ * sizeof(float16_t));
      batch_offset += gru_parm->hidden_size_;
    }
    gate_offset += gru_parm->batch_ * gru_parm->hidden_size_;
  }
}
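
// One GRU step. In the (original Cho et al.) formulation implemented below:
//   z = sigmoid(x * Wz + h * Rz + bz)       update gate
//   r = sigmoid(x * Wr + h * Rr + br)       reset gate
//   n = tanh(x * Wn + (r . h) * Rn + bn)    candidate state
//   h' = z . h + (1 - z) . n
// where "." is the element-wise product and each gate bias is the pre-folded
// sum of the input and recurrent biases (see InitWeightBias in the kernel).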
void GruStepUnitFp16(float16_t *output, const float16_t *input, const float16_t *input_reset_weight,
                     const float16_t *input_update_weight, const float16_t *input_hidden_weight,
                     const float16_t *state_reset_weight, const float16_t *state_update_weight,
                     const float16_t *state_hidden_weight, const float16_t *bias, float16_t *hidden_state,
                     float16_t *gate_buffer, const GruParameter *gru_parm) {
  InitGruGateFp16(gate_buffer, bias, gru_parm);
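  // gate_buffer is partitioned into three contiguous batch_ x hidden_size_
  // slabs: update gate z, reset gate r, candidate n.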
  float16_t *update_gate = gate_buffer;
  float16_t *reset_gate = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_;
  float16_t *hidden_buffer = gate_buffer + gru_parm->batch_ * gru_parm->hidden_size_ * 2;
  // input * weight
  MatMulAccFp16(reset_gate, input, input_reset_weight, gru_parm->batch_, gru_parm->hidden_size_, gru_parm->input_size_);
  MatMulAccFp16(update_gate, input, input_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
                gru_parm->input_size_);
  MatMulAccFp16(hidden_buffer, input, input_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
                gru_parm->input_size_);
  // state * weight
  MatMulAccFp16(reset_gate, hidden_state, state_reset_weight, gru_parm->batch_, gru_parm->hidden_size_,
                gru_parm->hidden_size_);
  MatMulAccFp16(update_gate, hidden_state, state_update_weight, gru_parm->batch_, gru_parm->hidden_size_,
                gru_parm->hidden_size_);
  // reset_gate = sigmoid(reset_gate), update_gate = sigmoid(update_gate)
  SigmoidFp16(reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  SigmoidFp16(update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  // reset_gate <- r . h
  ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_parm->batch_ * gru_parm->hidden_size_);
  // hidden_buffer += (r . h) * state_hidden_weight, then n = tanh(hidden_buffer)
  MatMulAccFp16(hidden_buffer, reset_gate, state_hidden_weight, gru_parm->batch_, gru_parm->hidden_size_,
                gru_parm->hidden_size_);
  TanhFp16(hidden_buffer, hidden_buffer, gru_parm->batch_ * gru_parm->hidden_size_);
  // hidden_state <- z . h
  ElementMulFp16(update_gate, hidden_state, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
  ArithmeticParameter parameter;
  parameter.in_elements_num0_ = 1;
  parameter.in_elements_num1_ = gru_parm->batch_ * gru_parm->hidden_size_;
  float16_t one = 1.0f;
  // update_gate <- 1 - z
  ElementOptSubFp16(&one, update_gate, update_gate, gru_parm->batch_ * gru_parm->hidden_size_, &parameter);
  // hidden_state += (1 - z) . n, i.e. h' = z . h + (1 - z) . n
  ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_);
  memcpy(output, hidden_state, gru_parm->batch_ * gru_parm->hidden_size_ * sizeof(float16_t));
}
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
             const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
             const GruParameter *gru_parm) {
  // forward
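  // weight_g stacks the three input-weight matrices of a direction in
  // (update, reset, hidden) order; weight_r stacks the recurrent matrices the
  // same way. A backward direction, if present, follows at block offset 3
  // (see the bidirectional branch below).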
  const float16_t *input_update_weight = weight_g;
  const float16_t *input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_;
  const float16_t *input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 2;
  const float16_t *state_update_weight = weight_r;
  const float16_t *state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_;
  const float16_t *state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 2;
  for (int t = 0; t < check_seq_len; t++) {
    const float16_t *input_ptr = input + t * gru_parm->input_step_;
    float16_t *output_ptr = output + t * gru_parm->output_step_;
    GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                    state_reset_weight, state_update_weight, state_hidden_weight, bias, hidden_state, gate_buffer,
                    gru_parm);
  }
  // zero out the extra forward outputs beyond check_seq_len
  for (int t = check_seq_len; t < gru_parm->seq_len_; t++) {
    float16_t *output_ptr = output + t * gru_parm->output_step_;
    for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
      output_ptr[i] = 0.0f;
    }
  }
  // backward
  if (gru_parm->bidirectional_) {
    input_update_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 3;
    input_reset_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 4;
    input_hidden_weight = weight_g + gru_parm->input_size_ * gru_parm->hidden_size_ * 5;
    state_update_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 3;
    state_reset_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 4;
    state_hidden_weight = weight_r + gru_parm->hidden_size_ * gru_parm->hidden_size_ * 5;
    float16_t *backward_output = output + gru_parm->batch_ * gru_parm->hidden_size_;
    const float16_t *backward_bias = bias + 3 * gru_parm->hidden_size_;
    float16_t *backward_hidden_state = hidden_state + gru_parm->batch_ * gru_parm->hidden_size_;
    for (int t = check_seq_len - 1; t >= 0; t--) {
      const float16_t *input_ptr = input + t * gru_parm->input_step_;
      float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
      GruStepUnitFp16(output_ptr, input_ptr, input_reset_weight, input_update_weight, input_hidden_weight,
                      state_reset_weight, state_update_weight, state_hidden_weight, backward_bias,
                      backward_hidden_state, gate_buffer, gru_parm);
    }
    // zero out the extra backward outputs beyond check_seq_len
    for (int t = gru_parm->seq_len_ - 1; t >= check_seq_len; t--) {
      float16_t *output_ptr = backward_output + t * gru_parm->output_step_;
      for (int i = 0; i < gru_parm->batch_ * gru_parm->hidden_size_; i++) {
        output_ptr[i] = 0.0f;
      }
    }
  }
}
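
For orientation, here is a minimal (hypothetical) driver for GruFp16. It is a
sketch only: the sizes are invented, error handling is elided, and in this patch
the buffers are actually owned by GruFp16CPUKernel (further down), not by user
code.

  #include <stdbool.h>
  #include <stdlib.h>
  #include "nnacl/fp16/gru_fp16.h"

  void RunGruFp16Example(void) {  /* hypothetical driver, not part of the patch */
    GruParameter parm = {0};
    parm.seq_len_ = 4;
    parm.batch_ = 1;
    parm.input_size_ = 8;
    parm.hidden_size_ = 16;
    parm.bidirectional_ = false;
    parm.input_step_ = parm.batch_ * parm.input_size_;    /* elements per time step */
    parm.output_step_ = parm.batch_ * parm.hidden_size_;  /* doubled when bidirectional */
    float16_t *input = calloc(parm.seq_len_ * parm.input_step_, sizeof(float16_t));
    float16_t *output = calloc(parm.seq_len_ * parm.output_step_, sizeof(float16_t));
    float16_t *weight_g = calloc(3 * parm.input_size_ * parm.hidden_size_, sizeof(float16_t));
    float16_t *weight_r = calloc(3 * parm.hidden_size_ * parm.hidden_size_, sizeof(float16_t));
    float16_t *bias = calloc(3 * parm.hidden_size_, sizeof(float16_t));  /* folded Wb + Rb */
    float16_t *hidden_state = calloc(parm.batch_ * parm.hidden_size_, sizeof(float16_t));
    float16_t *gate_buffer = malloc(3 * parm.batch_ * parm.hidden_size_ * sizeof(float16_t));
    GruFp16(output, input, weight_g, weight_r, bias, hidden_state, gate_buffer,
            parm.seq_len_ /* check_seq_len */, &parm);
    free(input);
    free(output);
    free(weight_g);
    free(weight_r);
    free(bias);
    free(hidden_state);
    free(gate_buffer);
  }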
@@ -0,0 +1,30 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_FP16_GRU_H_
#define MINDSPORE_LITE_NNACL_FP16_GRU_H_

#include "nnacl/gru_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif
void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r,
             const float16_t *bias, float16_t *hidden_state, float16_t *gate_buffer, int check_seq_len,
             const GruParameter *gru_parm);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP16_GRU_H_
@@ -15,21 +15,7 @@
  */
 #ifndef MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
 #define MINDSPORE_LITE_NNACL_FP32_GRU_FP32_H_
 #include "nnacl/op_base.h"
-typedef struct GruParameter {
-  // Primitive parameter
-  OpParameter op_parameter_;
-  // shape correlative
-  int input_size_;
-  int hidden_size_;  // output_size
-  int seq_len_;
-  int batch_;
-  // other parameter
-  int input_step_;
-  int output_step_;
-  bool bidirectional_;
-} GruParameter;
+#include "nnacl/gru_parameter.h"
 #ifdef __cplusplus
 extern "C" {
@@ -0,0 +1,35 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_

#include "nnacl/op_base.h"

typedef struct GruParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // shape correlative
  int input_size_;
  int hidden_size_;  // output_size
  int seq_len_;
  int batch_;
  // runtime strides, in elements, between consecutive time steps
  int input_step_;   // batch_ * input_size_
  int output_step_;  // batch_ * hidden_size_, doubled when bidirectional_
  bool bidirectional_;
} GruParameter;

#endif  // MINDSPORE_LITE_NNACL_GRU_PARAMETER_H_
@@ -0,0 +1,189 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/gru_fp16.h"
#include <algorithm>
#include <cstring>
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp16/gru_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gru;

namespace mindspore::kernel {
void GruFp16CPUKernel::FreeTmpBuffer() {
  if (gate_buffer_ != nullptr) {
    free(gate_buffer_);
    gate_buffer_ = nullptr;
  }
  if (bias_ptr_ != nullptr) {
    free(bias_ptr_);
    bias_ptr_ = nullptr;
  }
  if (weight_g_ptr_ != nullptr) {
    free(weight_g_ptr_);
    weight_g_ptr_ = nullptr;
  }
  if (weight_r_ptr_ != nullptr) {
    free(weight_r_ptr_);
    weight_r_ptr_ = nullptr;
  }
}
int GruFp16CPUKernel::InitParam() {
  auto input = in_tensors_.front();
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  gru_parm_->seq_len_ = in_shape.at(0);
  gru_parm_->batch_ = in_shape.at(1);
  gru_parm_->input_size_ = in_shape.at(2);

  auto weight_g = in_tensors_.at(1);
  MS_ASSERT(weight_g != nullptr);
  std::vector<int> w_shape = weight_g->shape();
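  // weight_g is expected to have shape (num_directions, 3 * hidden_size, input_size),
  // so dividing the stacked-gate dimension by 3 recovers hidden_size.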
  gru_parm_->hidden_size_ = w_shape.at(1) / 3;
  gru_parm_->input_step_ = gru_parm_->batch_ * gru_parm_->input_size_;
  gru_parm_->output_step_ = gru_parm_->bidirectional_ ? 2 * gru_parm_->batch_ * gru_parm_->hidden_size_
                                                      : gru_parm_->batch_ * gru_parm_->hidden_size_;
  return RET_OK;
}
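
// gate_buffer_ backs the three z/r/n slabs that GruStepUnitFp16 works in,
// hence the factor of 3 in the allocation below.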
int GruFp16CPUKernel::InitBuffer() {
  gate_buffer_ =
    reinterpret_cast<float16_t *>(malloc(3 * gru_parm_->batch_ * gru_parm_->hidden_size_ * sizeof(float16_t)));
  if (gate_buffer_ == nullptr) {
    MS_LOG(ERROR) << "GruFp16CPUKernel malloc gate_buffer error.";
    return RET_ERROR;
  }
  return RET_OK;
}
int GruFp16CPUKernel::InitWeightBias() {
  auto weight_gate = in_tensors_.at(1);
  MS_ASSERT(weight_gate != nullptr);
  weight_g_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_gate->ElementsNum() * sizeof(float16_t)));
  if (weight_g_ptr_ == nullptr) {
    MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_g_ptr_ error.";
    return RET_ERROR;
  }
  // convert the fp32 gate weights to fp16
  auto weight_g_data = reinterpret_cast<float *>(weight_gate->data_c());
  for (int i = 0; i < weight_gate->ElementsNum(); i++) {
    weight_g_ptr_[i] = (float16_t)weight_g_data[i];
  }

  auto weight_recu = in_tensors_.at(2);
  MS_ASSERT(weight_recu != nullptr);
  weight_r_ptr_ = reinterpret_cast<float16_t *>(malloc(weight_recu->ElementsNum() * sizeof(float16_t)));
  if (weight_r_ptr_ == nullptr) {
    MS_LOG(ERROR) << "GruFp16CPUKernel malloc weight_r_ptr_ error.";
    return RET_ERROR;
  }
  // convert the fp32 recurrent weights to fp16
  auto weight_r_data = reinterpret_cast<float *>(weight_recu->data_c());
  for (int i = 0; i < weight_recu->ElementsNum(); i++) {
    weight_r_ptr_[i] = (float16_t)weight_r_data[i];
  }

  int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
  bias_ptr_ = reinterpret_cast<float16_t *>(malloc(bias_num * sizeof(float16_t)));
  if (bias_ptr_ == nullptr) {
    MS_LOG(ERROR) << "GruFp16CPUKernel malloc bias_ptr_ error.";
    return RET_ERROR;
  }
  // fold the input bias (Wb) and the recurrent bias (Rb) of each gate into a
  // single vector; the step kernel adds one combined bias per gate
  auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->data_c());
  const int state_bias_offset = 3 * gru_parm_->hidden_size_;
  for (int i = 0; i < state_bias_offset; i++) {
    bias_ptr_[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
  }
  if (gru_parm_->bidirectional_) {
    bias_data += 3 * gru_parm_->hidden_size_ * 2;
    auto backward_bias = bias_ptr_ + 3 * gru_parm_->hidden_size_;
    for (int i = 0; i < state_bias_offset; i++) {
      backward_bias[i] = (float16_t)(bias_data[i] + bias_data[i + state_bias_offset]);
    }
  }
  return RET_OK;
}
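
// Bias folding example (hypothetical hidden_size_ = 2, unidirectional): the
// fp32 bias tensor packs [Wbz Wbr Wbn | Rbz Rbr Rbn] = 12 floats; the loop
// above stores Wb + Rb per gate, leaving 6 folded fp16 values in bias_ptr_.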
int GruFp16CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GruFp16CPUKernel::ReSize() {
  FreeTmpBuffer();
  auto ret = InitParam();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruFp16CPUKernel InitParam error.";
    return RET_ERROR;
  }
  ret = InitWeightBias();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruFp16CPUKernel InitWeightBias error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  ret = InitBuffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "GruFp16CPUKernel InitBuffer error.";
    FreeTmpBuffer();
    return RET_ERROR;
  }
  return RET_OK;
}
int GruFp16CPUKernel::Run() {
  auto input = in_tensors_.at(kInputIndex);
  MS_ASSERT(input != nullptr);
  auto hidden_state = in_tensors_.at(4);
  MS_ASSERT(hidden_state != nullptr);
  auto output = out_tensors_.at(0);
  MS_ASSERT(output != nullptr);
  auto input_ptr = reinterpret_cast<float16_t *>(input->data_c());
  MS_ASSERT(input_ptr);
  auto output_ptr = reinterpret_cast<float16_t *>(output->data_c());
  MS_ASSERT(output_ptr);

  // seed the output hidden state with the input hidden state
  auto output_hidden_state = out_tensors_.at(1);
  memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float16_t));
  int check_seq_len = gru_parm_->seq_len_;
  if (in_tensors_.size() == 6) {
    auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
    // all batch entries must share one sequence length
    if (!std::equal(seq_len + 1, seq_len + gru_parm_->batch_, seq_len)) {
      MS_LOG(ERROR) << "different batch seq_len is currently not supported";
      return RET_ERROR;
    }
    check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
  }

  MS_ASSERT(weight_g_ptr_ != nullptr);
  MS_ASSERT(weight_r_ptr_ != nullptr);
  MS_ASSERT(bias_ptr_ != nullptr);
  MS_ASSERT(gate_buffer_ != nullptr);
  GruFp16(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
          reinterpret_cast<float16_t *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Gru, LiteKernelCreator<GruFp16CPUKernel>)
}  // namespace mindspore::kernel
@@ -0,0 +1,52 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/gru_parameter.h"

namespace mindspore::kernel {
class GruFp16CPUKernel : public LiteKernel {
 public:
  GruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                   const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
    gru_parm_ = reinterpret_cast<GruParameter *>(op_parameter_);
  }
  ~GruFp16CPUKernel() override { FreeTmpBuffer(); }

  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  void FreeTmpBuffer();
  int InitParam();
  int InitBuffer();
  int InitWeightBias();

  float16_t *gate_buffer_ = nullptr;
  float16_t *weight_g_ptr_ = nullptr;
  float16_t *weight_r_ptr_ = nullptr;
  float16_t *bias_ptr_ = nullptr;
  GruParameter *gru_parm_ = nullptr;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GRU_H_
@@ -18,6 +18,7 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
+#include "nnacl/fp32/gru_fp32.h"
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -70,11 +71,21 @@ int GruCPUKernel::InitBuffer() {
 int GruCPUKernel::InitWeightBias() {
   auto weight_gate = in_tensors_.at(1);
   MS_ASSERT(weight_gate != nullptr);
-  weight_g_ptr_ = reinterpret_cast<float *>(weight_gate->data_c());
+  weight_g_ptr_ = reinterpret_cast<float *>(malloc(weight_gate->ElementsNum() * sizeof(float)));
+  if (weight_g_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruCPUKernel malloc weight_g_ptr_ error.";
+    return RET_ERROR;
+  }
+  memcpy(weight_g_ptr_, weight_gate->data_c(), weight_gate->ElementsNum() * sizeof(float));
   auto weight_recu = in_tensors_.at(2);
   MS_ASSERT(weight_recu != nullptr);
-  weight_r_ptr_ = reinterpret_cast<float *>(weight_recu->data_c());
+  weight_r_ptr_ = reinterpret_cast<float *>(malloc(weight_recu->ElementsNum() * sizeof(float)));
+  if (weight_r_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "GruCPUKernel malloc weight_r_ptr_ error.";
+    return RET_ERROR;
+  }
+  memcpy(weight_r_ptr_, weight_recu->data_c(), weight_recu->ElementsNum() * sizeof(float));
   int bias_num = gru_parm_->bidirectional_ ? 2 * 3 * gru_parm_->hidden_size_ : 3 * gru_parm_->hidden_size_;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
@@ -17,7 +17,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRU_FP32_H_
 #include <vector>
 #include "src/lite_kernel.h"
-#include "nnacl/fp32/gru_fp32.h"
+#include "nnacl/gru_parameter.h"
 namespace mindspore::kernel {
 class GruCPUKernel : public LiteKernel {
@@ -42,8 +42,8 @@ class GruCPUKernel : public LiteKernel {
   int InitWeightBias();
   float *gate_buffer_ = nullptr;
-  const float *weight_g_ptr_ = nullptr;
-  const float *weight_r_ptr_ = nullptr;
+  float *weight_g_ptr_ = nullptr;
+  float *weight_r_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
   GruParameter *gru_parm_ = nullptr;
 };