| @@ -124,6 +124,18 @@ void CheckIfValidType(const TypePtr &type) { | |||
| } | |||
| } | |||
| void SetTensorType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { | |||
| TypePtr elem_type = dyn_cast<TensorType>(type)->element(); | |||
| type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| if (shape != nullptr && shape->isa<abstract::Shape>()) { | |||
| abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | |||
| for (const auto &elem : shape_info->shape()) { | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | |||
| } | |||
| } | |||
| } | |||
| void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { | |||
| if (type_proto == nullptr) { | |||
| return; | |||
| @@ -136,15 +148,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s | |||
| if (type->isa<Number>()) { | |||
| type_proto->set_data_type(GetNumberDataType(type)); | |||
| } else if (type->isa<TensorType>()) { | |||
| TypePtr elem_type = dyn_cast<TensorType>(type)->element(); | |||
| type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| if (shape != nullptr && shape->isa<abstract::Shape>()) { | |||
| abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | |||
| for (const auto &elem : shape_info->shape()) { | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | |||
| } | |||
| } | |||
| SetTensorType(type, shape, type_proto); | |||
| } else if (type->isa<Tuple>()) { | |||
| TuplePtr tuple_type = dyn_cast<Tuple>(type); | |||
| type_proto->set_data_type(irpb::DT_TUPLE); | |||
| @@ -91,6 +91,28 @@ bool GetTensorOrScalarTypeInfo(const TypePtr &arg_type_origin, TypeId *arg_type_ | |||
| return false; | |||
| } | |||
| TypeId GetMaxTypeIdForNumber(TypeId max_type_id, bool has_int8, bool has_scalar_int64, bool has_scalar_float32) { | |||
| if (max_type_id == kNumberTypeUInt8 && has_int8) { | |||
| max_type_id = kNumberTypeInt16; | |||
| } | |||
| // if bool is the max type, see if there is scalar input | |||
| // if so, it means that max is bool tensor, use scalar type instead. | |||
| // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) | |||
| if (max_type_id == kNumberTypeBool) { | |||
| if (has_scalar_int64) { | |||
| max_type_id = kNumberTypeInt64; | |||
| } | |||
| if (has_scalar_float32) { | |||
| max_type_id = kNumberTypeFloat32; | |||
| } | |||
| } | |||
| if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 && | |||
| max_type_id != kTypeUnknown && has_scalar_float32) { | |||
| max_type_id = kNumberTypeFloat32; | |||
| } | |||
| return max_type_id; | |||
| } | |||
| TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<size_t> &indices) { | |||
| TypeId max_type_id = kTypeUnknown; | |||
| size_t max_type_number = 0; | |||
| @@ -126,26 +148,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<s | |||
| SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); | |||
| } | |||
| } | |||
| if (max_type_id == kNumberTypeUInt8 && has_int8) { | |||
| max_type_id = kNumberTypeInt16; | |||
| } | |||
| // if bool is the max type, see if there is scalar input | |||
| // if so, it means that max is bool tensor, use scalar type instead. | |||
| // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) | |||
| if (max_type_id == kNumberTypeBool) { | |||
| if (has_scalar_int64) { | |||
| max_type_id = kNumberTypeInt64; | |||
| } | |||
| if (has_scalar_float32) { | |||
| max_type_id = kNumberTypeFloat32; | |||
| } | |||
| } | |||
| if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 && | |||
| max_type_id != kTypeUnknown && has_scalar_float32) { | |||
| max_type_id = kNumberTypeFloat32; | |||
| } | |||
| return max_type_id; | |||
| return GetMaxTypeIdForNumber(max_type_id, has_int8, has_scalar_int64, has_scalar_float32); | |||
| } | |||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | |||
| @@ -257,6 +260,18 @@ void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBas | |||
| } | |||
| } | |||
| SignatureEnumRW GetSignatureEnumRW(size_t index, const std::vector<Signature> &signature, bool has_var) { | |||
| SignatureEnumRW sig = SignatureEnumRW::kRWDefault; | |||
| // If sig_size is 0 use default. | |||
| std::size_t sig_size = signature.size(); | |||
| if (index < sig_size) { | |||
| sig = signature[index].rw; | |||
| } else if (has_var && index >= sig_size) { | |||
| sig = signature[sig_size - 1].rw; | |||
| } | |||
| return sig; | |||
| } | |||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) { | |||
| // args: original inputs | |||
| @@ -277,14 +292,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| op_inputs.push_back(param); | |||
| continue; | |||
| } | |||
| SignatureEnumRW sig = SignatureEnumRW::kRWDefault; | |||
| // If sig_size is 0 use default. | |||
| if (sig_size > 0 && i < sig_size) { | |||
| sig = signature[i].rw; | |||
| } else if (has_var && i >= sig_size) { | |||
| sig = signature[sig_size - 1].rw; | |||
| } | |||
| SignatureEnumRW sig = GetSignatureEnumRW(i, signature, has_var); | |||
| TypePtr type = args_spec_list[i]->BuildType(); | |||
| if (type && type->isa<RefType>()) { | |||
| if (sig == SignatureEnumRW::kRWRead) { | |||
| @@ -31,7 +31,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| using ParamUserMap = std::unordered_map<std::string, std::vector<size_t>>; | |||
| using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>; | |||
| @@ -146,9 +146,8 @@ class InlinerBase : public AnfVisitor { | |||
| if (IsForceInline(this, fg, node)) { | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| return InlineMove(node, fg, args, inputs); | |||
| } else { | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| @@ -960,6 +960,27 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr | |||
| return std::make_shared<ValueNode>(prim); | |||
| } | |||
| bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) { | |||
| // Handle control flow operator. | |||
| auto operatorPtr = cnode_ptr->input(0); | |||
| // Set abstract of switch(c,f,t),switchLayer(c,tup) and | |||
| // partial(func,args) to null | |||
| auto prim = GetValueNode<PrimitivePtr>(operatorPtr); | |||
| if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) || | |||
| IsPrimitiveEquals(prim::kPrimPartial, prim)) { | |||
| cnode_ptr->set_abstract(nullptr); | |||
| return true; | |||
| } | |||
| // If the operator is not a primitive, the abstract will been set to null. | |||
| // Because there are not some operators in front end, the abstract of primitive should be reserved. | |||
| if (prim == nullptr) { | |||
| cnode_ptr->set_abstract(nullptr); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type) { | |||
| if (node_type == "UpdateState") { | |||
| cnode_ptr->set_abstract(kUMonad->ToAbstract()); | |||
| @@ -984,22 +1005,7 @@ void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, cons | |||
| // Set CNode abstract. | |||
| void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) { | |||
| const std::string &node_type = node_proto.op_type(); | |||
| // Handle control flow operator. | |||
| auto operatorPtr = cnode_ptr->input(0); | |||
| // Set abstract of switch(c,f,t),switchLayer(c,tup) and | |||
| // partial(func,args) to null | |||
| auto prim = GetValueNode<PrimitivePtr>(operatorPtr); | |||
| if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) || | |||
| IsPrimitiveEquals(prim::kPrimPartial, prim)) { | |||
| cnode_ptr->set_abstract(nullptr); | |||
| return; | |||
| } | |||
| // If the operator is not a primitive, the abstract will been set to null. | |||
| // Because there are not some operators in front end, the abstract of primitive should be reserved. | |||
| if (prim == nullptr && need_renormalize()) { | |||
| cnode_ptr->set_abstract(nullptr); | |||
| if (CheckCNodePrim(cnode_ptr)) { | |||
| return; | |||
| } | |||
| @@ -1020,6 +1026,7 @@ void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CN | |||
| // Because there is not context in unit test, | |||
| // abstract->broaden() is replaced by abstract->set_value(kAnyValue). | |||
| const std::string &node_type = node_proto.op_type(); | |||
| if (kv.size() == 0) { | |||
| SetEmptyTensorProtoCNodeAbstract(cnode_ptr, node_type); | |||
| } else if (kv.size() == 1 && !is_tuple_or_list) { | |||
| @@ -71,6 +71,7 @@ class MSANFModelParser { | |||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); | |||
| bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); | |||
| AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto); | |||
| bool CheckCNodePrim(CNodePtr cnode_ptr); | |||
| void SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type); | |||
| void SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr); | |||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); | |||
| @@ -571,7 +571,8 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na | |||
| } | |||
| TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type, | |||
| const std::set<TypePtr> &template_types, const string &prim_name) { | |||
| const std::set<TypePtr> &template_types, const string &prim_name, | |||
| bool is_mix) { | |||
| if (CheckType(type, template_types)) { | |||
| return type; | |||
| } | |||
| @@ -584,6 +585,11 @@ TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const | |||
| } | |||
| buffer << " Tensor[" << item->ToString() << "],"; | |||
| } | |||
| if (is_mix) { | |||
| for (const auto &item : template_types) { | |||
| buffer << " " << item->ToString() << "],"; | |||
| } | |||
| } | |||
| buffer << "}, but got " << type->ToString(); | |||
| buffer << "."; | |||
| MS_EXCEPTION(TypeError) << buffer.str(); | |||
| @@ -613,7 +619,7 @@ TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::s | |||
| (void)input_names.append(item.first); | |||
| (void)input_names.append(", "); | |||
| } | |||
| return CheckMixSubClass(input_names, arg_, valid_values, prim_name); | |||
| return CheckTensorSubClass(input_names, arg_, valid_values, prim_name, true); | |||
| } | |||
| TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name, | |||
| @@ -809,26 +815,4 @@ bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_l | |||
| } | |||
| return false; | |||
| } | |||
| TypePtr CheckAndConvertUtils::CheckMixSubClass(const string &type_name, const TypePtr &type, | |||
| const std::set<TypePtr> &template_types, const string &prim_name) { | |||
| if (CheckType(type, template_types)) { | |||
| return type; | |||
| } | |||
| std::ostringstream buffer; | |||
| buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of {"; | |||
| for (const auto &item : template_types) { | |||
| if (item->isa<TensorType>()) { | |||
| buffer << item->ToString(); | |||
| continue; | |||
| } | |||
| buffer << " Tensor[" << item->ToString() << "],"; | |||
| } | |||
| for (const auto &item : template_types) { | |||
| buffer << " " << item->ToString() << "],"; | |||
| } | |||
| buffer << "}, but got " << type->ToString(); | |||
| buffer << "."; | |||
| MS_EXCEPTION(TypeError) << buffer.str(); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -322,9 +322,8 @@ class CheckAndConvertUtils { | |||
| static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name, | |||
| const bool allow_mix); | |||
| static TypePtr CheckTensorSubClass(const std::string &type_name, const TypePtr &type, | |||
| const std::set<TypePtr> &template_types, const std::string &prim_name); | |||
| static TypePtr CheckMixSubClass(const std::string &type_name, const TypePtr &type, | |||
| const std::set<TypePtr> &template_types, const std::string &prim_name); | |||
| const std::set<TypePtr> &template_types, const std::string &prim_name, | |||
| bool is_mix = false); | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ | |||