Browse Source

!30723 [MS][LITE]support onnx fp64 cast node and parameter node

Merge pull request !30723 from luoyuan/support-fp64-greater
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
5041fa9a6b
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 43 additions and 6 deletions
  1. +6
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc
  2. +3
    -0
      mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc
  3. +13
    -5
      mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc
  4. +21
    -1
      mindspore/lite/tools/optimizer/common/gllo_utils.cc

+ 6
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc View File

@@ -59,6 +59,11 @@ int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32) {
indices_data_[i] = static_cast<int>(reinterpret_cast<float *>(indices_tensor->MutableData())[i]);
}
break;
case kNumberTypeBool:
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = static_cast<int>(reinterpret_cast<bool *>(indices_tensor->MutableData())[i]);
}
break;
default:
MS_LOG(ERROR) << "Does not support data type: " << indices_tensor->data_type();
return RET_ERROR;
@@ -71,4 +76,5 @@ int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32) {

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
} // namespace mindspore::kernel

+ 3
- 0
mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc View File

@@ -32,6 +32,9 @@ ops::PrimitiveC *OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const
if (dst_type == kNumberTypeInt64) {
dst_type = kNumberTypeInt32;
}
if (dst_type == kNumberTypeFloat64) {
dst_type = kNumberTypeFloat32;
}
prim->AddAttr("to", MakeValue(static_cast<int32_t>(dst_type)));
}
}


+ 13
- 5
mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc View File

@@ -71,7 +71,8 @@ STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int
return lite::RET_OK;
}

STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr &param_node) {
STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr &param_node, TypeId input,
TypeId output) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(param_node != nullptr);
if (param_node->abstract() == nullptr) {
@@ -87,8 +88,8 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter
MS_LOG(ERROR) << "get typePtr failed.";
return lite::RET_NULL_PTR;
}
if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeInt64) {
MS_LOG(DEBUG) << "don't need to convert to int32.";
if (abstract_tensor->element()->GetTypeTrack()->type_id() != input) {
MS_LOG(DEBUG) << "The actual type is not the input type, don't need to convert.";
return lite::RET_OK;
}
auto manager = func_graph->manager();
@@ -113,7 +114,9 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter
func_graph->DropNode(param_node);
} else {
// set graph input
param_node->abstract()->set_type(TypeIdToType(kNumberTypeInt32));
if (abstract_tensor->element()->GetTypeTrack()->type_id() == input) {
param_node->abstract()->set_type(TypeIdToType(output));
}
}
return lite::RET_OK;
}
@@ -390,7 +393,12 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) {
for (auto &node : node_list) {
if (utils::isa<ParameterPtr>(node)) {
auto param_node = node->cast<ParameterPtr>();
status = ReplaceInt64ParameterNode(func_graph, param_node);
status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeFloat64, kNumberTypeFloat32);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "replace fp64 param node failed.";
return status;
}
status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeInt64, kNumberTypeInt32);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "replace int64 param node failed.";
return status;


+ 21
- 1
mindspore/lite/tools/optimizer/common/gllo_utils.cc View File

@@ -212,6 +212,21 @@ int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
tensor_data[i] = static_cast<int>(origin_data[i]);
}
}
} else if (tensor_info->data_type() == kNumberTypeFloat64) {
// Narrow float64 tensor data to float32, saturating values outside the float32 range.
auto *tensor_data = reinterpret_cast<float *>(tensor_info_dst->data_c());
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return RET_ERROR;
}
auto *origin_data = reinterpret_cast<double_t *>(tensor_info->data_c());
for (size_t i = 0; i < data_count; ++i) {
// BUGFIX: the lower bound of float32 is -FLT_MAX, not FLT_MIN. FLT_MIN is the smallest
// *positive* normal float (~1.18e-38), so the old check `< FLT_MIN` fired for every zero
// and negative value and clamped them all to +1.18e-38, corrupting the tensor.
if (origin_data[i] > static_cast<double_t>(FLT_MAX) || origin_data[i] < static_cast<double_t>(-FLT_MAX)) {
MS_LOG(WARNING) << "float64 data " << origin_data[i] << " cannot fit into float32";
tensor_data[i] = origin_data[i] > 0 ? FLT_MAX : -FLT_MAX;
} else {
tensor_data[i] = static_cast<float>(origin_data[i]);
}
}
} else {
tensor_info_dst->set_data_type(tensor_info->data_type());
auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_dst->data_c());
@@ -724,7 +739,12 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
std::vector<int64_t> shape_vector;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
[](const int &val) { return static_cast<int64_t>(val); });
auto data_type = tensor_info->data_type() == kNumberTypeInt64 ? kNumberTypeInt32 : tensor_info->data_type();
auto data_type = tensor_info->data_type();
if (tensor_info->data_type() == kNumberTypeInt64) {
data_type = kNumberTypeInt32;
} else if (tensor_info->data_type() == kNumberTypeFloat64) {
data_type = kNumberTypeFloat32;
}
param_node->set_name(node->fullname_with_scope());
auto tensor_info_new = std::make_shared<tensor::Tensor>(data_type, shape_vector);
if (tensor_info_new == nullptr) {


Loading…
Cancel
Save