|
|
|
@@ -165,7 +165,7 @@ int LstmFp16CPUKernel::InitStateWeightBias() { |
|
|
|
if (weight_h->data_type() == kNumberTypeFloat32) { |
|
|
|
Float32ToFloat16(reinterpret_cast<float *>(weight_h->data_c()), weight_h_ptr_, weight_h->ElementsNum()); |
|
|
|
} else if (weight_h->data_type() == kNumberTypeFloat16) { |
|
|
|
memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->ElementsNum()); |
|
|
|
memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h->data_c()), weight_h->Size()); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; |
|
|
|
return RET_ERROR; |
|
|
|
|