Merge pull request !7187 from zhangbuxue/improve_the_implicit_conversion_rule_when_there_are_int_tensor_and_float_numbertags/v1.1.0
| @@ -68,8 +68,7 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ | |||||
| *max_type_number = type_number; | *max_type_number = type_number; | ||||
| } | } | ||||
| bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, bool is_write, TypeId *arg_type_id, | |||||
| TypeId *arg_type = nullptr) { | |||||
| bool GetTensorOrScalarTypeInfo(TypePtr arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) { | |||||
| if (arg_type_origin->isa<TensorType>()) { | if (arg_type_origin->isa<TensorType>()) { | ||||
| auto tensor = arg_type_origin->cast<TensorTypePtr>(); | auto tensor = arg_type_origin->cast<TensorTypePtr>(); | ||||
| auto tensor_type = tensor->element(); | auto tensor_type = tensor->element(); | ||||
| @@ -102,8 +101,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> | |||||
| for (const auto &index : indices) { | for (const auto &index : indices) { | ||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| TypeId arg_type = kTypeUnknown; | TypeId arg_type = kTypeUnknown; | ||||
| auto is_write = (write_indices.find(index) != write_indices.end()); | |||||
| if (!GetTensorOrScalarTypeInfo(input_types[index], is_write, &arg_type_id, &arg_type)) { | |||||
| if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (arg_type != kObjectTypeTensorType) { | if (arg_type != kObjectTypeTensorType) { | ||||
| @@ -144,6 +142,10 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t> | |||||
| max_type_id = kNumberTypeFloat32; | max_type_id = kNumberTypeFloat32; | ||||
| } | } | ||||
| } | } | ||||
| if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 && | |||||
| has_scalar_float32) { | |||||
| max_type_id = kNumberTypeFloat32; | |||||
| } | |||||
| return max_type_id; | return max_type_id; | ||||
| } | } | ||||
| @@ -218,7 +220,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| auto arg_value = input_types[i]; | auto arg_value = input_types[i]; | ||||
| (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | |||||
| (void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id); | |||||
| auto it_map = type_name_map.find(arg_type_id); | auto it_map = type_name_map.find(arg_type_id); | ||||
| if (it_map == type_name_map.end()) { | if (it_map == type_name_map.end()) { | ||||
| continue; | continue; | ||||
| @@ -223,6 +223,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, | |||||
| max_type = TypeId::kNumberTypeFloat32; | max_type = TypeId::kNumberTypeFloat32; | ||||
| } | } | ||||
| } | } | ||||
| if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 && | |||||
| max_type != TypeId::kNumberTypeFloat64 && has_float) { | |||||
| max_type = TypeId::kNumberTypeFloat32; | |||||
| } | |||||
| if (max_type == TypeId::kNumberTypeUInt8 && has_int8) { | if (max_type == TypeId::kNumberTypeUInt8 && has_int8) { | ||||
| max_type = TypeId::kNumberTypeInt16; | max_type = TypeId::kNumberTypeInt16; | ||||
| } | } | ||||
| @@ -126,7 +126,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, | |||||
| std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | ||||
| std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | ||||
| std::set<AnfNodePtr> key_ward_para_nodes; | |||||
| std::set<AnfNodePtr> kwarg_nodes; | |||||
| for (const auto &kwarg : kwarg_list) { | for (const auto &kwarg : kwarg_list) { | ||||
| MS_EXCEPTION_IF_NULL(kwarg); | MS_EXCEPTION_IF_NULL(kwarg); | ||||
| std::string kw_param_name = kwarg->get_key(); | std::string kw_param_name = kwarg->get_key(); | ||||
| @@ -160,14 +160,13 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, | |||||
| } else { | } else { | ||||
| auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); | auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); | ||||
| // multiply values found given for parameter | // multiply values found given for parameter | ||||
| if (node_itr != specialized_parameter_list->end() && | |||||
| key_ward_para_nodes.find(param_node) == key_ward_para_nodes.end()) { | |||||
| if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) { | |||||
| MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name; | MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name; | ||||
| } else { | } else { | ||||
| specialized_parameter_list->push_back(param_node); | specialized_parameter_list->push_back(param_node); | ||||
| auto extract_node = specialized_graph->NewCNode( | auto extract_node = specialized_graph->NewCNode( | ||||
| {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); | {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); | ||||
| key_ward_para_nodes.insert(param_node); | |||||
| kwarg_nodes.insert(param_node); | |||||
| (void)repl_nodes->emplace(param_node, extract_node); | (void)repl_nodes->emplace(param_node, extract_node); | ||||
| } | } | ||||
| } | } | ||||