Browse Source

!15147 solve the precision bug of fp16 lstm op

From: @wangyanling10
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/15147/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
8bf983c58d
3 changed files with 3 additions and 3 deletions
  1. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc
  3. +1
    -1
      mindspore/lite/test/models_with_multiple_inputs_fp16.cfg

+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/gru_fp16.cc View File

@@ -157,7 +157,7 @@ int GruFp16CPUKernel::InitStateWeightBias() {
if (weight_r->data_type() == kNumberTypeFloat32) {
Float32ToFloat16(reinterpret_cast<float *>(weight_r->data_c()), weight_r_ptr_, weight_r->ElementsNum());
} else if (weight_r->data_type() == kNumberTypeFloat16) {
memcpy(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_r->ElementsNum());
memcpy(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data_c()), weight_r->Size());
} else {
MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru.";
return RET_ERROR;


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/lstm_fp16.cc View File

@@ -165,7 +165,7 @@ int LstmFp16CPUKernel::InitStateWeightBias() {
if (weight_h->data_type() == kNumberTypeFloat32) {
Float32ToFloat16(reinterpret_cast<float *>(weight_h->data_c()), weight_h_ptr_, weight_h->ElementsNum());
} else if (weight_h->data_type() == kNumberTypeFloat16) {
memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->ElementsNum());
memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->Size());
} else {
MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm.";
return RET_ERROR;


+ 1
- 1
mindspore/lite/test/models_with_multiple_inputs_fp16.cfg View File

@@ -14,7 +14,7 @@ ml_female_model_step6_noiseout.pb;66 2
ml_male_model_step6_noiseout.pb;66 2.5
ml_tts_encoder_control_flow.pb;4;1:1,22:1:1 1.5
ml_tts_decoder_control_flow.pb;5 1
ml_tts_decoder.pb;5 117
ml_tts_decoder.pb;5 2.5
# The input of hiai_cv_labelDetectorModel_v3.tflite is between 0-255.
hiai_cv_labelDetectorModel_v3.tflite;2 2
ml_tts_vocoder.pb;66 53


Loading…
Cancel
Save