|
|
|
@@ -71,7 +71,8 @@ STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node) { |
|
|
|
STATUS ReplaceTypeParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, TypeId input, |
|
|
|
TypeId output) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
MS_ASSERT(param_node != nullptr); |
|
|
|
if (param_node->abstract() == nullptr) { |
|
|
|
@@ -87,8 +88,8 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter |
|
|
|
MS_LOG(ERROR) << "get typePtr failed."; |
|
|
|
return lite::RET_NULL_PTR; |
|
|
|
} |
|
|
|
if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeInt64) { |
|
|
|
MS_LOG(DEBUG) << "don't need to convert to int32."; |
|
|
|
if (abstract_tensor->element()->GetTypeTrack()->type_id() != input) { |
|
|
|
MS_LOG(DEBUG) << "The actual type is not the input type, don't need to convert."; |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
@@ -113,7 +114,9 @@ STATUS ReplaceInt64ParameterNode(const FuncGraphPtr &func_graph, const Parameter |
|
|
|
func_graph->DropNode(param_node); |
|
|
|
} else { |
|
|
|
// set graph input |
|
|
|
param_node->abstract()->set_type(TypeIdToType(kNumberTypeInt32)); |
|
|
|
if (abstract_tensor->element()->GetTypeTrack()->type_id() == input) { |
|
|
|
param_node->abstract()->set_type(TypeIdToType(output)); |
|
|
|
} |
|
|
|
} |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
@@ -390,7 +393,12 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph) { |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (utils::isa<ParameterPtr>(node)) { |
|
|
|
auto param_node = node->cast<ParameterPtr>(); |
|
|
|
status = ReplaceInt64ParameterNode(func_graph, param_node); |
|
|
|
status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeFloat64, kNumberTypeFloat32); |
|
|
|
if (status != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "replace fp64 param node failed."; |
|
|
|
return status; |
|
|
|
} |
|
|
|
status = ReplaceTypeParameterNode(func_graph, param_node, kNumberTypeInt64, kNumberTypeInt32); |
|
|
|
if (status != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "replace int64 param node failed."; |
|
|
|
return status; |
|
|
|
|