diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc
index ab8d4b9418..1510aa5063 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc
@@ -59,6 +59,11 @@ int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32) {
         indices_data_[i] = static_cast<int>(reinterpret_cast<int64_t *>(indices_tensor->MutableData())[i]);
       }
       break;
+    case kNumberTypeBool:
+      for (int i = 0; i < indices_num; i++) {
+        indices_data_[i] = static_cast<int>(reinterpret_cast<bool *>(indices_tensor->MutableData())[i]);
+      }
+      break;
     default:
       MS_LOG(ERROR) << "Does not support data type: " << indices_tensor->data_type();
       return RET_ERROR;
@@ -71,4 +76,5 @@
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
 } // namespace mindspore::kernel
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
index c642bb8133..57978b0c15 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
@@ -32,6 +32,9 @@ ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const
       if (dst_type == kNumberTypeInt64) {
         dst_type = kNumberTypeInt32;
       }
+      if (dst_type == kNumberTypeFloat64) {
+        dst_type = kNumberTypeFloat32;
+      }
       prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));
     }
   }
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc
index cc7a70e08f..75504f4dab 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc
+++ 
b/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc
@@ -71,7 +71,8 @@ STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
   return lite::RET_OK;
 }
 
-STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr &param_node) {
+STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr &param_node, TypeId input,
+                                TypeId output) {
   MS_ASSERT(func_graph != nullptr);
   MS_ASSERT(param_node != nullptr);
   if (param_node->abstract() == nullptr) {
@@ -87,8 +88,8 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter
     MS_LOG(ERROR) << "get typePtr failed.";
     return lite::RET_NULL_PTR;
   }
-  if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeInt64) {
-    MS_LOG(DEBUG) << "don't need to convert to int32.";
+  if (abstract_tensor->element()->GetTypeTrack()->type_id() != input) {
+    MS_LOG(DEBUG) << "The actual type is not the input type, don't need to convert.";
     return lite::RET_OK;
   }
   auto manager = func_graph->manager();
@@ -113,7 +114,9 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter
     func_graph->DropNode(param_node);
   } else {
     // set graph input
-    param_node->abstract()->set_type(TypeIdToType(kNumberTypeInt32));
+    if (abstract_tensor->element()->GetTypeTrack()->type_id() == input) {
+      param_node->abstract()->set_type(TypeIdToType(output));
+    }
   }
   return lite::RET_OK;
 }
@@ -390,7 +393,12 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
   for (auto &node : node_list) {
     if (utils::isa<ParameterPtr>(node)) {
       auto param_node = node->cast<ParameterPtr>();
-      status = ReplaceInt64ParameterNode(func_graph, param_node);
+      status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeFloat64, kNumberTypeFloat32);
+      if (status != lite::RET_OK) {
+        MS_LOG(ERROR) << "replace fp64 param node failed.";
+        return status;
+      }
+      status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeInt64, kNumberTypeInt32);
       if (status != lite::RET_OK) {
         MS_LOG(ERROR) << "replace int64 param node failed.";
         return status;
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 478f63eb9d..31bf4a8260 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -212,6 +212,21 @@ int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
         tensor_data[i] = static_cast<int32_t>(origin_data[i]);
       }
     }
+  } else if (tensor_info->data_type() == kNumberTypeFloat64) {
+    auto *tensor_data = reinterpret_cast<float *>(tensor_info_dst->data_c());
+    if (tensor_data == nullptr) {
+      MS_LOG(ERROR) << "new data failed";
+      return RET_ERROR;
+    }
+    auto *origin_data = reinterpret_cast<double *>(tensor_info->data_c());
+    for (size_t i = 0; i < data_count; ++i) {
+      if (origin_data[i] > static_cast<double>(FLT_MAX) || origin_data[i] < static_cast<double>(FLT_MIN)) {
+        MS_LOG(WARNING) << "float64 data " << origin_data[i] << " cannot fit into float32";
+        tensor_data[i] = origin_data[i] > 0 ? FLT_MAX : FLT_MIN;
+      } else {
+        tensor_data[i] = static_cast<float>(origin_data[i]);
+      }
+    }
   } else {
     tensor_info_dst->set_data_type(tensor_info->data_type());
     auto *tensor_data = reinterpret_cast<char *>(tensor_info_dst->data_c());
@@ -724,7 +739,12 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
   std::vector<int64_t> shape_vector;
   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
                  [](const int &val) { return static_cast<int64_t>(val); });
-  auto data_type = tensor_info->data_type() == kNumberTypeInt64 ? kNumberTypeInt32 : tensor_info->data_type();
+  auto data_type = tensor_info->data_type();
+  if (tensor_info->data_type() == kNumberTypeInt64) {
+    data_type = kNumberTypeInt32;
+  } else if (tensor_info->data_type() == kNumberTypeFloat64) {
+    data_type = kNumberTypeFloat32;
+  }
   param_node->set_name(node->fullname_with_scope());
   auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
   if (tensor_info_new == nullptr) {