From: @lianliguang Reviewed-by: @ginfung,@chujinjin Signed-off-by: @chujinjinpull/14548/MERGE
| @@ -30,45 +30,6 @@ namespace prim { | |||
| ValuePtr GetPythonOps(const std::string &op_name, | |||
| const std::string &module_name = "mindspore._extends.parse.standard_method", | |||
| bool use_signature = false); | |||
| // Primitives only used by frontend; | |||
| // Type introspection | |||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||
| class UnpackGraphPrimitive : public Primitive { | |||
| public: | |||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | |||
| @@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt | |||
| return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods()); | |||
| } | |||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: Ref, value, [universal] | |||
| CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); | |||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||
| auto type = args_spec_list[0]->BuildType(); | |||
| if (type->type_id() == kObjectTypeRefKey) { | |||
| return args_spec_list[1]->Broaden(); | |||
| } else { | |||
| return args_spec_list[0]; | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: Ref/Tensor, universal | |||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||
| auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||
| if (ref_abs != nullptr) { | |||
| // Return tensor value if input is Ref. | |||
| return ref_abs->CloneAsTensor(); | |||
| } | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); | |||
| REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, | |||
| InferImplBroadcastGradientArgs); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs, | |||
| nullptr, false); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| class RegisterFrontendPrimitiveEvalHelper { | |||
| public: | |||
| RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, false}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterFrontendPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||
| static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl) | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| @@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| abstract->set_value(value); | |||
| return abstract; | |||
| } | |||
| AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: Ref/Tensor, universal | |||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||
| auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||
| if (ref_abs != nullptr) { | |||
| // Return tensor value if input is Ref. | |||
| return ref_abs->CloneAsTensor(); | |||
| } | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: Ref, value, [universal] | |||
| CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); | |||
| MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; | |||
| auto type = args_spec_list[0]->BuildType(); | |||
| if (type->type_id() == kObjectTypeRefKey) { | |||
| return args_spec_list[1]->Broaden(); | |||
| } else { | |||
| return args_spec_list[0]; | |||
| } | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -55,158 +55,161 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) { | |||
| PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| static PrimitiveEvalImplMap prim_eval_implement_map = { | |||
| // Statements | |||
| {prim::kPrimReturn, {InferImplReturn, true}}, | |||
| {prim::kPrimSwitch, {InferImplSwitch, true}}, | |||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, | |||
| {prim::kPrimIs_, {InferImplIs_, true}}, | |||
| {prim::kPrimIsNot, {InferImplIsNot, true}}, | |||
| {prim::kPrimInDict, {InferImplInDict, true}}, | |||
| {prim::kPrimNotInDict, {InferImplNotInDict, true}}, | |||
| {prim::kPrimIsConsant, {InferImplIsConstant, true}}, | |||
| {prim::kPrimReturn, {InferImplReturn, nullptr, true}}, | |||
| {prim::kPrimSwitch, {InferImplSwitch, nullptr, true}}, | |||
| {prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}}, | |||
| {prim::kPrimIs_, {InferImplIs_, nullptr, true}}, | |||
| {prim::kPrimIsNot, {InferImplIsNot, nullptr, true}}, | |||
| {prim::kPrimInDict, {InferImplInDict, nullptr, true}}, | |||
| {prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}}, | |||
| {prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}}, | |||
| // Maths | |||
| {prim::kPrimSquare, {InferImplSquare, true}}, | |||
| {prim::kPrimMatMul, {InferImplMatMul, true}}, | |||
| {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}}, | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, | |||
| {prim::kPrimSqrt, {InferImplSqrt, true}}, | |||
| {prim::kPrimSquare, {InferImplSquare, nullptr, true}}, | |||
| {prim::kPrimMatMul, {InferImplMatMul, nullptr, true}}, | |||
| {prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}}, | |||
| {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, | |||
| {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}}, | |||
| {prim::kPrimSqrt, {InferImplSqrt, nullptr, true}}, | |||
| // Array | |||
| {prim::kPrimRange, {InferImplRange, true}}, | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, | |||
| {prim::kPrimUnique, {InferImplUnique, true}}, | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, | |||
| {prim::kPrimGather, {InferImplGatherV2, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, | |||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, | |||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | |||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | |||
| {prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}}, | |||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | |||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | |||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}}, | |||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, true}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | |||
| {prim::kPrimRange, {InferImplRange, nullptr, true}}, | |||
| {prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}}, | |||
| {prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}}, | |||
| {prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}}, | |||
| {prim::kPrimUnique, {InferImplUnique, nullptr, true}}, | |||
| {prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}}, | |||
| {prim::kPrimGather, {InferImplGatherV2, nullptr, true}}, | |||
| {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}}, | |||
| {prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}}, | |||
| {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}}, | |||
| {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}}, | |||
| {prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}}, | |||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}}, | |||
| {prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}}, | |||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}}, | |||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}}, | |||
| {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}}, | |||
| {prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}}, | |||
| {prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}}, | |||
| {prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, nullptr, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| {prim::kPrimMakeDict, {InferImplMakeDict, true}}, | |||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, | |||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, | |||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, | |||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, | |||
| {prim::kPrimListGetItem, {InferImplListGetItem, true}}, | |||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, | |||
| {prim::kPrimListSetItem, {InferImplListSetItem, true}}, | |||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, | |||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}}, | |||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, true}}, | |||
| {prim::kPrimArrayLen, {InferImplArrayLen, true}}, | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}}, | |||
| {prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}}, | |||
| {prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}}, | |||
| {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}}, | |||
| {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}}, | |||
| {prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}}, | |||
| {prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}}, | |||
| {prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}}, | |||
| {prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}}, | |||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}}, | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}}, | |||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}}, | |||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, nullptr, true}}, | |||
| {prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}}, | |||
| // NN | |||
| {prim::kPrimPooling, {InferImplPooling, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, | |||
| {prim::kPrimBatchNorm, {InferImplBatchNorm, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, true}}, | |||
| {prim::kPrimConv2D, {InferImplConv2D, true}}, | |||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, true}}, | |||
| {prim::kPrimRelu, {InferImplRelu, true}}, | |||
| {prim::kPrimRelu6, {InferImplRelu, true}}, | |||
| {prim::kPrimZerosLike, {InferImplZerosLike, true}}, | |||
| {prim::kPrimBpropCut, {InferImplBpropCut, true}}, | |||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, | |||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, | |||
| {prim::kPrimDropout, {InferImplDropout, true}}, | |||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, | |||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, | |||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, | |||
| {prim::kPrimSGD, {InferImplSGD, true}}, | |||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}}, | |||
| {prim::kPrimPooling, {InferImplPooling, nullptr, true}}, | |||
| {prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}}, | |||
| {prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}}, | |||
| {prim::kPrimReluGrad, {InferImplReluGrad, nullptr, true}}, | |||
| {prim::kPrimConv2D, {InferImplConv2D, nullptr, true}}, | |||
| {prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}}, | |||
| {prim::kPrimRelu, {InferImplRelu, nullptr, true}}, | |||
| {prim::kPrimRelu6, {InferImplRelu, nullptr, true}}, | |||
| {prim::kPrimZerosLike, {InferImplZerosLike, nullptr, true}}, | |||
| {prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}}, | |||
| {prim::kPrimLayerNorm, {InferImplLayerNorm, nullptr, true}}, | |||
| {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, nullptr, true}}, | |||
| {prim::kPrimDropout, {InferImplDropout, nullptr, true}}, | |||
| {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, nullptr, true}}, | |||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}}, | |||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}}, | |||
| {prim::kPrimSGD, {InferImplSGD, nullptr, true}}, | |||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}}, | |||
| // Others | |||
| {prim::kPrimIdentity, {InferImplIdentity, true}}, | |||
| {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}}, | |||
| {prim::kPrimLoad, {InferImplLoad, nullptr, true}}, | |||
| {prim::kPrimAssign, {InferImplAssign, nullptr, true}}, | |||
| // Set impl to null as it will use PartialEvaluator; | |||
| {prim::kPrimPartial, {nullptr, true}}, | |||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, | |||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, | |||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, | |||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, | |||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | |||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | |||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||
| {prim::kPrimUpdateState, {InferImplUpdateState, true}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | |||
| {prim::kPrimPartial, {nullptr, nullptr, true}}, | |||
| {prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}}, | |||
| {prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}}, | |||
| {prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}}, | |||
| {prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}}, | |||
| {prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}}, | |||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}}, | |||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}}, | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, nullptr, true}}, | |||
| {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}}, | |||
| // Debug | |||
| {prim::kPrimDebug, {InferImplDebug, true}}, | |||
| {prim::kPrimDebug, {InferImplDebug, nullptr, true}}, | |||
| // Dynamic shape testing | |||
| {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}}, | |||
| {prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}}, | |||
| // SparseTensor | |||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}}, | |||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}}, | |||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}}, | |||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}}, | |||
| {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}}, | |||
| {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}}, | |||
| {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}}, | |||
| {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}}, | |||
| // RowTensor | |||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}}, | |||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, | |||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, | |||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, | |||
| {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, | |||
| {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}}, | |||
| {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}}, | |||
| {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}}, | |||
| {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}}, | |||
| {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}}, | |||
| // Comm Ops | |||
| {prim::kPrimAllSwap, {InferImplAllSwap, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, | |||
| {prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}}, | |||
| {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { | |||
| static PrimitiveEvalImplMap prim_backend_eval_implement_map = { | |||
| {prim::kPrimMul, {InferImplMul, true}}, | |||
| {prim::kPrimAdd, {InferImplAdd, true}}, | |||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, | |||
| {prim::kPrimSub, {InferImplSub, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, true}}, | |||
| {prim::kPrimReduceSum, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMean, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAll, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceAny, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMax, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimReduceMin, {InferImplReduceFunc, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, | |||
| {prim::kPrimCast, {InferImplCast, true}}, | |||
| {prim::kPrimExpandDims, {InferImplExpandDims, true}}, | |||
| {prim::kPrimAllReduce, {InferImplAllReduce, true}}, | |||
| {prim::kPrimBroadcast, {InferImplBroadcast, true}}, | |||
| {prim::kPrimAllGather, {InferImplAllGather, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, true}}, | |||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, | |||
| {prim::kPrimLinSpace, {InferImplLinSpace, true}}, | |||
| {prim::kPrimAddN, {InferImplAddN, true}}, | |||
| {prim::kPrimMul, {InferImplMul, nullptr, true}}, | |||
| {prim::kPrimAdd, {InferImplAdd, nullptr, true}}, | |||
| {prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}}, | |||
| {prim::kPrimSub, {InferImplSub, nullptr, true}}, | |||
| {prim::kPrimEqual, {InferImplEqual, nullptr, true}}, | |||
| {prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}}, | |||
| {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}}, | |||
| {prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}}, | |||
| {prim::kPrimCast, {InferImplCast, nullptr, true}}, | |||
| {prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}}, | |||
| {prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}}, | |||
| {prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}}, | |||
| {prim::kPrimAllGather, {InferImplAllGather, nullptr, true}}, | |||
| {prim::kPrimMinimum, {InferImplMinimum, nullptr, true}}, | |||
| {prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}}, | |||
| {prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}}, | |||
| {prim::kPrimAddN, {InferImplAddN, nullptr, true}}, | |||
| {prim::kPrimLess, {InferImplLess, true}}, | |||
| {prim::kPrimStack, {InferImplStack, true}}, | |||
| {prim::kPrimPad, {InferImplPad, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimDiv, {InferImplDiv, true}}, | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimConcat, {InferImplConcat, true}}, | |||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}}, | |||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}}, | |||
| {prim::kPrimLess, {InferImplLess, nullptr, true}}, | |||
| {prim::kPrimStack, {InferImplStack, nullptr, true}}, | |||
| {prim::kPrimPad, {InferImplPad, nullptr, true}}, | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}}, | |||
| {prim::kPrimDiv, {InferImplDiv, nullptr, true}}, | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}}, | |||
| {prim::kPrimShape, {InferImplShape, nullptr, false}}, | |||
| {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, nullptr, true}}, | |||
| {prim::kPrimConcat, {InferImplConcat, nullptr, true}}, | |||
| {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}}, | |||
| {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}}, | |||
| }; | |||
| return prim_backend_eval_implement_map; | |||
| } | |||
| @@ -28,9 +28,13 @@ namespace mindspore { | |||
| namespace abstract { | |||
| using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &); | |||
| using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); | |||
| struct StandardPrimitiveImplReg { | |||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. | |||
| bool in_white_list_; // true if this Primitive in white list, else false. | |||
| StandardPrimitiveEvalImpl impl_; // Implement function of Primitive | |||
| InferValueEvalImpl infer_value_func_; // infer value of primitive | |||
| // true means this primitive can be executed by vm backend else will be constant folded by frontend | |||
| bool in_white_list_; | |||
| }; | |||
| using PrimitiveEvalImplMap = | |||
| @@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard | |||
| class RegisterStandardPrimitiveEvalHelper { | |||
| public: | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, true}; | |||
| RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl, | |||
| const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { | |||
| const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list}; | |||
| RegisterStandardPrimitiveImpl(primitive, impl_reg); | |||
| } | |||
| ~RegisterStandardPrimitiveEvalHelper() = default; | |||
| }; | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \ | |||
| static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl) | |||
| #define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \ | |||
| static auto helper_##name = \ | |||
| abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list) | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ | |||
| @@ -541,6 +541,40 @@ inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType"); | |||
| inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion"); | |||
| inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf"); | |||
| // Type introspection | |||
| inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); | |||
| inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype"); | |||
| inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve"); | |||
| inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed"); | |||
| inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed"); | |||
| inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance"); | |||
| // Other miscellaneous | |||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); | |||
| inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record"); | |||
| // Structures | |||
| inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map"); | |||
| inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce"); | |||
| inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed"); | |||
| inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape"); | |||
| inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div"); | |||
| inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array"); | |||
| inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul"); | |||
| inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal"); | |||
| inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal"); | |||
| inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range"); | |||
| inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient"); | |||
| inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal"); | |||
| inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat"); | |||
| inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len"); | |||
| inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); | |||
| inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs"); | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||
| @@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer); | |||
| REGISTER_PRIMITIVE_C(kNameAdd, Add); | |||
| } // namespace mindspore | |||
| @@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); | |||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1))); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); | |||
| REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri | |||
| EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true); | |||
| return args_spec_list[0]; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true); | |||
| AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| EXPECT_EQ(args_spec_list.size(), 3); | |||
| @@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c | |||
| auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>(); | |||
| return args_spec_list[0]; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest); | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true); | |||
| class TestAttrAndDynamicBackendInfer : public UT::Common { | |||
| public: | |||
| TestAttrAndDynamicBackendInfer() {} | |||