From: @lianliguang Reviewed-by: @ginfung,@chujinjin Signed-off-by: @chujinjinpull/14548/MERGE
| @@ -30,45 +30,6 @@ namespace prim { | |||||
| ValuePtr GetPythonOps(const std::string &op_name, | ValuePtr GetPythonOps(const std::string &op_name, | ||||
| const std::string &module_name = "mindspore._extends.parse.standard_method", | const std::string &module_name = "mindspore._extends.parse.standard_method", | ||||
| bool use_signature = false); | bool use_signature = false); | ||||
| // Primitives only used by frontend; | |||||
| // Type introspection | |||||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||||
| // Other miscellaneous | |||||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||||
| // Structures | |||||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||||
| class UnpackGraphPrimitive : public Primitive { | class UnpackGraphPrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | ||||
| @@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt | |||||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | ||||
| } | } | ||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: Ref, value, [universal] | |||||
| CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); | |||||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| if (type->type_id() == kObjectTypeRefKey) { | |||||
| return args_spec_list[1]->Broaden(); | |||||
| } else { | |||||
| return args_spec_list[0]; | |||||
| } | |||||
| } | |||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: Ref/Tensor, universal | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||||
| if (ref_abs != nullptr) { | |||||
| // Return tensor value if input is Ref. | |||||
| return ref_abs->CloneAsTensor(); | |||||
| } | |||||
| return args_spec_list[0]->Broaden(); | |||||
| } | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | |||||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | |||||
| InferImplBroadcastGradientArgs); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs, | |||||
| nullptr, false); | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| class RegisterFrontendPrimitiveEvalHelper { | |||||
| public: | |||||
| RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||||
| const StandardPrimitiveImplReg impl_reg{impl, false}; | |||||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||||
| } | |||||
| ~RegisterFrontendPrimitiveEvalHelper() = default; | |||||
| }; | |||||
| #define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||||
| static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl) | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| template <typename T> | template <typename T> | ||||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | ||||
| // Inputs: a tuple or list or dict. | // Inputs: a tuple or list or dict. | ||||
| @@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||||
| abstract->set_value(value); | abstract->set_value(value); | ||||
| return abstract; | return abstract; | ||||
| } | } | ||||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: Ref/Tensor, universal | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||||
| auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||||
| if (ref_abs != nullptr) { | |||||
| // Return tensor value if input is Ref. | |||||
| return ref_abs->CloneAsTensor(); | |||||
| } | |||||
| return args_spec_list[0]->Broaden(); | |||||
| } | |||||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: Ref, value, [universal] | |||||
| CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); | |||||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||||
| auto type = args_spec_list[0]->BuildType(); | |||||
| if (type->type_id() == kObjectTypeRefKey) { | |||||
| return args_spec_list[1]->Broaden(); | |||||
| } else { | |||||
| return args_spec_list[0]; | |||||
| } | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -55,158 +55,161 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | ||||
| static PrimitiveEvalImplMap prim_eval_implement_map = { | static PrimitiveEvalImplMap prim_eval_implement_map = { | ||||
| // Statements | // Statements | ||||
| {prim::kPrimReturn, {InferImplReturn, true}}, | |||||
| {prim::kPrimSwitch, {InferImplSwitch, true}}, | |||||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, | |||||
| {prim::kPrimIs_, {InferImplIs_, true}}, | |||||
| {prim::kPrimIsNot, {InferImplIsNot, true}}, | |||||
| {prim::kPrimInDict, {InferImplInDict, true}}, | |||||
| {prim::kPrimNotInDict, {InferImplNotInDict, true}}, | |||||
| {prim::kPrimIsConsant, {InferImplIsConstant, true}}, | |||||
| {prim::kPrimReturn, {InferImplReturn, nullptr, true}}, | |||||
| {prim::kPrimSwitch, {InferImplSwitch, nullptr, true}}, | |||||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}}, | |||||
| {prim::kPrimIs_, {InferImplIs_, nullptr, true}}, | |||||
| {prim::kPrimIsNot, {InferImplIsNot, nullptr, true}}, | |||||
| {prim::kPrimInDict, {InferImplInDict, nullptr, true}}, | |||||
| {prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}}, | |||||
| {prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}}, | |||||
| // Maths | // Maths | ||||
| {prim::kPrimSquare, {InferImplSquare, true}}, | |||||
| {prim::kPrimMatMul, {InferImplMatMul, true}}, | |||||
| {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, | |||||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | |||||
| {prim::kPrimSquare, {InferImplSquare, nullptr, true}}, | |||||
| {prim::kPrimMatMul, {InferImplMatMul, nullptr, true}}, | |||||
| {prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}}, | |||||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, | |||||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, | |||||
| {prim::kPrimSqrt, {InferImplSqrt, nullptr, true}}, | |||||
| // Array | // Array | ||||
| {prim::kPrimRange, {InferImplRange, true}}, | |||||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||||
| {prim::kPrimUnique, {InferImplUnique, true}}, | |||||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||||
| {prim::kPrimGather, {InferImplGatherV2, true}}, | |||||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | |||||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | |||||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | |||||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}}, | |||||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | |||||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | |||||
| {prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}}, | |||||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | |||||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | |||||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, | |||||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, | |||||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | |||||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | |||||
| {prim::kPrimRange, {InferImplRange, nullptr, true}}, | |||||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}}, | |||||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}}, | |||||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, | |||||
| {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, | |||||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, | |||||
| {prim::kPrimGather, {InferImplGatherV2, nullptr, true}}, | |||||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, | |||||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, | |||||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, | |||||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}}, | |||||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}}, | |||||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}}, | |||||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}}, | |||||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}}, | |||||
| {prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}}, | |||||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}}, | |||||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}}, | |||||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}}, | |||||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}}, | |||||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}}, | |||||
| {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}}, | |||||
| {prim::kPrimSplit, {InferImplSplit, nullptr, true}}, | |||||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, | |||||
| // Structure | // Structure | ||||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||||
| {prim::kPrimMakeDict, {InferImplMakeDict, true}}, | |||||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, | |||||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, | |||||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, | |||||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, | |||||
| {prim::kPrimListGetItem, {InferImplListGetItem, true}}, | |||||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, | |||||
| {prim::kPrimListSetItem, {InferImplListSetItem, true}}, | |||||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, | |||||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, | |||||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}}, | |||||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, true}}, | |||||
| {prim::kPrimListAppend, {InferImplListAppend, true}}, | |||||
| {prim::kPrimTupleLen, {InferImplTupleLen, true}}, | |||||
| {prim::kPrimListLen, {InferImplListLen, true}}, | |||||
| {prim::kPrimArrayLen, {InferImplArrayLen, true}}, | |||||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, | |||||
| {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, | |||||
| {prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}}, | |||||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}}, | |||||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}}, | |||||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}}, | |||||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}}, | |||||
| {prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}}, | |||||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}}, | |||||
| {prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}}, | |||||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}}, | |||||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}}, | |||||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}}, | |||||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}}, | |||||
| {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}}, | |||||
| {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}}, | |||||
| {prim::kPrimListLen, {InferImplListLen, nullptr, true}}, | |||||
| {prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}}, | |||||
| // NN | // NN | ||||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||||
| {prim::kPrimBatchNorm, {InferImplBatchNorm, true}}, | |||||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | |||||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, | |||||
| {prim::kPrimRelu, {InferImplRelu, true}}, | |||||
| {prim::kPrimRelu6, {InferImplRelu, true}}, | |||||
| {prim::kPrimZerosLike, {InferImplZerosLike, true}}, | |||||
| {prim::kPrimBpropCut, {InferImplBpropCut, true}}, | |||||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, | |||||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, | |||||
| {prim::kPrimDropout, {InferImplDropout, true}}, | |||||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, | |||||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, | |||||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, | |||||
| {prim::kPrimSGD, {InferImplSGD, true}}, | |||||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}}, | |||||
| {prim::kPrimPooling, {InferImplPooling, nullptr, true}}, | |||||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}}, | |||||
| {prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}}, | |||||
| {prim::kPrimReluGrad, {InferImplReluGrad, nullptr, true}}, | |||||
| {prim::kPrimConv2D, {InferImplConv2D, nullptr, true}}, | |||||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}}, | |||||
| {prim::kPrimRelu, {InferImplRelu, nullptr, true}}, | |||||
| {prim::kPrimRelu6, {InferImplRelu, nullptr, true}}, | |||||
| {prim::kPrimZerosLike, {InferImplZerosLike, nullptr, true}}, | |||||
| {prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}}, | |||||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, nullptr, true}}, | |||||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, nullptr, true}}, | |||||
| {prim::kPrimDropout, {InferImplDropout, nullptr, true}}, | |||||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, nullptr, true}}, | |||||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}}, | |||||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}}, | |||||
| {prim::kPrimSGD, {InferImplSGD, nullptr, true}}, | |||||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}}, | |||||
| // Others | // Others | ||||
| {prim::kPrimIdentity, {InferImplIdentity, true}}, | |||||
| {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}}, | |||||
| {prim::kPrimLoad, {InferImplLoad, nullptr, true}}, | |||||
| {prim::kPrimAssign, {InferImplAssign, nullptr, true}}, | |||||
| // Set impl to null as it will use PartialEvaluator; | // Set impl to null as it will use PartialEvaluator; | ||||
| {prim::kPrimPartial, {nullptr, true}}, | |||||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, | |||||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, | |||||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, | |||||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, | |||||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | |||||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | |||||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | |||||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | |||||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||||
| {prim::kPrimUpdateState, {InferImplUpdateState, true}}, | |||||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | |||||
| {prim::kPrimPartial, {nullptr, nullptr, true}}, | |||||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}}, | |||||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}}, | |||||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}}, | |||||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}}, | |||||
| {prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}}, | |||||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}}, | |||||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}}, | |||||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, | |||||
| {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, | |||||
| {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, | |||||
| {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, | |||||
| // Debug | // Debug | ||||
| {prim::kPrimDebug, {InferImplDebug, true}}, | |||||
| {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, | |||||
| // Dynamic shape testing | // Dynamic shape testing | ||||
| {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}}, | |||||
| {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}}, | |||||
| // SparseTensor | // SparseTensor | ||||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | |||||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | |||||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, | |||||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, | |||||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}}, | |||||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}}, | |||||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}}, | |||||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}}, | |||||
| // RowTensor | // RowTensor | ||||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, | |||||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, | |||||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, | |||||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||||
| {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, | |||||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}}, | |||||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}}, | |||||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}}, | |||||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}}, | |||||
| {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}}, | |||||
| // Comm Ops | // Comm Ops | ||||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | |||||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||||
| {prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}}, | |||||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}}, | |||||
| }; | }; | ||||
| return prim_eval_implement_map; | return prim_eval_implement_map; | ||||
| } | } | ||||
| PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | ||||
| static PrimitiveEvalImplMap prim_backend_eval_implement_map = { | static PrimitiveEvalImplMap prim_backend_eval_implement_map = { | ||||
| {prim::kPrimMul, {InferImplMul, true}}, | |||||
| {prim::kPrimAdd, {InferImplAdd, true}}, | |||||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||||
| {prim::kPrimSub, {InferImplSub, true}}, | |||||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||||
| {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, | |||||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||||
| {prim::kPrimCast, {InferImplCast, true}}, | |||||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||||
| {prim::kPrimAllReduce, {InferImplAllReduce, true}}, | |||||
| {prim::kPrimBroadcast, {InferImplBroadcast, true}}, | |||||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | |||||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||||
| {prim::kPrimAddN, {InferImplAddN, true}}, | |||||
| {prim::kPrimMul, {InferImplMul, nullptr, true}}, | |||||
| {prim::kPrimAdd, {InferImplAdd, nullptr, true}}, | |||||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, | |||||
| {prim::kPrimSub, {InferImplSub, nullptr, true}}, | |||||
| {prim::kPrimEqual, {InferImplEqual, nullptr, true}}, | |||||
| {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}}, | |||||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}}, | |||||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}}, | |||||
| {prim::kPrimCast, {InferImplCast, nullptr, true}}, | |||||
| {prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}}, | |||||
| {prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}}, | |||||
| {prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}}, | |||||
| {prim::kPrimAllGather, {InferImplAllGather, nullptr, true}}, | |||||
| {prim::kPrimMinimum, {InferImplMinimum, nullptr, true}}, | |||||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}}, | |||||
| {prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}}, | |||||
| {prim::kPrimAddN, {InferImplAddN, nullptr, true}}, | |||||
| {prim::kPrimLess, {InferImplLess, true}}, | |||||
| {prim::kPrimStack, {InferImplStack, true}}, | |||||
| {prim::kPrimPad, {InferImplPad, true}}, | |||||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||||
| {prim::kPrimDiv, {InferImplDiv, true}}, | |||||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||||
| {prim::kPrimShape, {InferImplShape, false}}, | |||||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||||
| {prim::kPrimConcat, {InferImplConcat, true}}, | |||||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, | |||||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | |||||
| {prim::kPrimLess, {InferImplLess, nullptr, true}}, | |||||
| {prim::kPrimStack, {InferImplStack, nullptr, true}}, | |||||
| {prim::kPrimPad, {InferImplPad, nullptr, true}}, | |||||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}}, | |||||
| {prim::kPrimDiv, {InferImplDiv, nullptr, true}}, | |||||
| {prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}}, | |||||
| {prim::kPrimShape, {InferImplShape, nullptr, false}}, | |||||
| {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, | |||||
| {prim::kPrimReshape, {InferImplReshape, nullptr, true}}, | |||||
| {prim::kPrimConcat, {InferImplConcat, nullptr, true}}, | |||||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}}, | |||||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}}, | |||||
| }; | }; | ||||
| return prim_backend_eval_implement_map; | return prim_backend_eval_implement_map; | ||||
| } | } | ||||
| @@ -28,9 +28,13 @@ namespace mindspore { | |||||
| namespace abstract { | namespace abstract { | ||||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| const AbstractBasePtrList &); | const AbstractBasePtrList &); | ||||
| using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); | |||||
| struct StandardPrimitiveImplReg { | struct StandardPrimitiveImplReg { | ||||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. | |||||
| bool in_white_list_; // true if this Primitive in white list, else false. | |||||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive | |||||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||||
| // true means this primitive can be executed by vm backend else will be constant folded by frontend | |||||
| bool in_white_list_; | |||||
| }; | }; | ||||
| using PrimitiveEvalImplMap = | using PrimitiveEvalImplMap = | ||||
| @@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||||
| class RegisterStandardPrimitiveEvalHelper { | class RegisterStandardPrimitiveEvalHelper { | ||||
| public: | public: | ||||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||||
| const StandardPrimitiveImplReg impl_reg{impl, true}; | |||||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl, | |||||
| const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { | |||||
| const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list}; | |||||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | RegisterStandardPrimitiveImpl(primitive, impl_reg); | ||||
| } | } | ||||
| ~RegisterStandardPrimitiveEvalHelper() = default; | ~RegisterStandardPrimitiveEvalHelper() = default; | ||||
| }; | }; | ||||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||||
| static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl) | |||||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \ | |||||
| static auto helper_##name = \ | |||||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list) | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | ||||
| @@ -541,6 +541,40 @@ inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType"); | |||||
| inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion"); | inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion"); | ||||
| inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf"); | inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf"); | ||||
| // Type introspection | |||||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||||
| // Other miscellaneous | |||||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||||
| // Structures | |||||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | ||||
| @@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | REGISTER_PRIMITIVE_C(kNameAdd, Add); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); | |||||
| REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | ||||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); | |||||
| REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri | |||||
| EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); | EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); | ||||
| return args_spec_list[0]; | return args_spec_list[0]; | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true); | |||||
| AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| EXPECT_EQ(args_spec_list.size(), 3); | EXPECT_EQ(args_spec_list.size(), 3); | ||||
| @@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c | |||||
| auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>(); | auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>(); | ||||
| return args_spec_list[0]; | return args_spec_list[0]; | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true); | |||||
| class TestAttrAndDynamicBackendInfer : public UT::Common { | class TestAttrAndDynamicBackendInfer : public UT::Common { | ||||
| public: | public: | ||||
| TestAttrAndDynamicBackendInfer() {} | TestAttrAndDynamicBackendInfer() {} | ||||