Merge pull request !4160 from vlne-v1/remove-ref-origintags/v0.7.0-beta
| @@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL | |||||
| } | } | ||||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | ||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("hyper_map"); | |||||
| FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>(); | |||||
| ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptr_graph->debug_info()->set_name("hyper_map"); | |||||
| AnfNodePtr ptrFnArg = nullptr; | AnfNodePtr ptrFnArg = nullptr; | ||||
| std::size_t i = 0; | std::size_t i = 0; | ||||
| ArgsPairList argmap; | ArgsPairList argmap; | ||||
| ArgsPairList argmap2; | ArgsPairList argmap2; | ||||
| if (fn_leaf_ == nullptr) { | if (fn_leaf_ == nullptr) { | ||||
| ptrFnArg = ptrGraph->add_parameter(); | |||||
| ptrFnArg = ptr_graph->add_parameter(); | |||||
| i = 1; | i = 1; | ||||
| } | } | ||||
| std::size_t size = args_spec_list.size(); | std::size_t size = args_spec_list.size(); | ||||
| for (; i < size; ++i) { | for (; i < size; ++i) { | ||||
| argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); | |||||
| argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i])); | |||||
| } | } | ||||
| argmap2 = Harmonize(ptrGraph, argmap); | |||||
| ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); | |||||
| return ptrGraph; | |||||
| argmap2 = Harmonize(ptr_graph, argmap); | |||||
| ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2)); | |||||
| return ptr_graph; | |||||
| } | } | ||||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | ||||
| @@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||||
| inputs.push_back(opsTupleItem); | inputs.push_back(opsTupleItem); | ||||
| inputs.push_back(cnode); | inputs.push_back(cnode); | ||||
| inputs.push_back(NewValueNode(1)); | inputs.push_back(NewValueNode(1)); | ||||
| AnfNodePtr ptrBprop = ret->NewCNode(inputs); | |||||
| AnfNodePtr ptr_bprop = ret->NewCNode(inputs); | |||||
| doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); | |||||
| doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights, | |||||
| ValueNodePtr opsTupleItem) { | ValueNodePtr opsTupleItem) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| AnfNodePtr ptrBPropArg = nullptr; | |||||
| AnfNodePtr ptr_bprop_arg = nullptr; | |||||
| if (sens_param_) { | if (sens_param_) { | ||||
| ptrBPropArg = func_graph->add_parameter(); | |||||
| ptr_bprop_arg = func_graph->add_parameter(); | |||||
| } else { | } else { | ||||
| auto ones_like = prim::GetPythonOps("ones_like"); | auto ones_like = prim::GetPythonOps("ones_like"); | ||||
| ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); | |||||
| ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out}); | |||||
| } | } | ||||
| AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); | |||||
| AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg}); | |||||
| CNodePtr fv_bprop = nullptr; | CNodePtr fv_bprop = nullptr; | ||||
| if (get_by_list_) { | if (get_by_list_) { | ||||
| // python code: grads = hyper_map(F.partial(env_get, env), weights) | // python code: grads = hyper_map(F.partial(env_get, env), weights) | ||||
| AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); | |||||
| AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)}); | |||||
| AnfNodePtr partial_env_get = | AnfNodePtr partial_env_get = | ||||
| func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); | func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); | ||||
| MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>(); | MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>(); | ||||
| @@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An | |||||
| CNodePtr inputs_bprop = nullptr; | CNodePtr inputs_bprop = nullptr; | ||||
| if (get_all_) { | if (get_all_) { | ||||
| inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); | |||||
| inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp}); | |||||
| } | } | ||||
| // Gradients wrt inputs and parameters | // Gradients wrt inputs and parameters | ||||
| @@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An | |||||
| } | } | ||||
| // Gradients wrt first input. | // Gradients wrt first input. | ||||
| // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input | |||||
| func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); | |||||
| // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input | |||||
| func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)})); | |||||
| } | } | ||||
| // Generate the graph. | // Generate the graph. | ||||
| @@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp | |||||
| auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn); | auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn); | ||||
| MS_EXCEPTION_IF_NULL(real_fn); | MS_EXCEPTION_IF_NULL(real_fn); | ||||
| FuncGraphPtr ptrGraph = real_fn->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(ptrGraph); | |||||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); | |||||
| FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>(); | |||||
| FuncGraphPtr ptr_graph = real_fn->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(ptr_graph); | |||||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); | |||||
| FuncGraphPtr df_builder = std::make_shared<FuncGraph>(); | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| auto nparam = ptrGraph->parameters().size(); | |||||
| auto nparam = ptr_graph->parameters().size(); | |||||
| std::ostringstream ss; | std::ostringstream ss; | ||||
| ss << "grad{" << nparam << "}"; | ss << "grad{" << nparam << "}"; | ||||
| dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| dfBuilder->debug_info()->set_name(ss.str()); | |||||
| ParameterPtr param_graph = dfBuilder->add_parameter(); | |||||
| df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| df_builder->debug_info()->set_name(ss.str()); | |||||
| ParameterPtr param_graph = df_builder->add_parameter(); | |||||
| AnfNodePtr weights = nullptr; | AnfNodePtr weights = nullptr; | ||||
| if (get_by_list_) { | if (get_by_list_) { | ||||
| weights = dfBuilder->add_parameter(); | |||||
| weights = df_builder->add_parameter(); | |||||
| } | } | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(prim::kPrimJ)); | inputs.push_back(NewValueNode(prim::kPrimJ)); | ||||
| inputs.push_back(param_graph); | inputs.push_back(param_graph); | ||||
| auto jf = dfBuilder->NewCNode(inputs); | |||||
| auto jf = df_builder->NewCNode(inputs); | |||||
| // df is checked in GetGrad | // df is checked in GetGrad | ||||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); | |||||
| auto df = GetGrad(jf, weights, ptrGraph->parameters()); | |||||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); | |||||
| auto df = GetGrad(jf, weights, ptr_graph->parameters()); | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| dfBuilder->set_output(NewValueNode(df)); | |||||
| df_builder->set_output(NewValueNode(df)); | |||||
| return dfBuilder; | |||||
| return df_builder; | |||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | ||||
| @@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ | |||||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | ||||
| TypeId *arg_type = nullptr) { | TypeId *arg_type = nullptr) { | ||||
| if (arg_value->isa<abstract::AbstractRef>()) { | if (arg_value->isa<abstract::AbstractRef>()) { | ||||
| if (is_write) { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||||
| } else { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||||
| auto ref = arg_value->cast<abstract::AbstractRefPtr>(); | |||||
| arg_value = ref->ref(); | |||||
| if (!is_write && ref->need_cast()) { | |||||
| auto tensor_type = ref->target_type(); | |||||
| *arg_type_id = tensor_type->type_id(); | |||||
| if (arg_type != nullptr) { | |||||
| *arg_type = kObjectTypeTensorType; | |||||
| } | |||||
| return true; | |||||
| } | } | ||||
| } | } | ||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | if (arg_value->isa<abstract::AbstractTensor>()) { | ||||
| @@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { | if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id | |||||
| << " to " << it->second; | |||||
| (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); | (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); | ||||
| } | } | ||||
| } | } | ||||
| @@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| TypePtr type = args_spec_list[i]->GetTypeTrack(); | TypePtr type = args_spec_list[i]->GetTypeTrack(); | ||||
| if (type && type->type_id() == kObjectTypeRef) { | if (type && type->type_id() == kObjectTypeRef) { | ||||
| auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>(); | |||||
| if (sig == SignatureEnumRW::kRWRead) { | if (sig == SignatureEnumRW::kRWRead) { | ||||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); | |||||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||||
| if (ref_abs && ref_abs->need_cast()) { | |||||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); | |||||
| } | |||||
| } else if (sig == SignatureEnumRW::kRWWrite) { | } else if (sig == SignatureEnumRW::kRWWrite) { | ||||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); | |||||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||||
| write_indices.insert(i); | write_indices.insert(i); | ||||
| } | } | ||||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | // If sig is SignatureEnumRW::kRWRef, not do anything. | ||||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | ||||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; | MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | |||||
| << args_spec_list[i]->ToString(); | |||||
| op_inputs.push_back(param); | op_inputs.push_back(param); | ||||
| } | } | ||||
| // process default | // process default | ||||
| @@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||||
| MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; | MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; | ||||
| } | } | ||||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | |||||
| // No need to check, check will be done in infer. | |||||
| auto ret_graph = std::make_shared<FuncGraph>(); | auto ret_graph = std::make_shared<FuncGraph>(); | ||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | ||||
| ret_graph->debug_info()->set_name("UnpackCall"); | |||||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | |||||
| AnfNodePtr fn_node = ret_graph->add_parameter(); | |||||
| std::vector<AnfNodePtr> elems; | std::vector<AnfNodePtr> elems; | ||||
| elems.push_back(fnNode); | |||||
| elems.push_back(fn_node); | |||||
| for (size_t index = 1; index < arg_length; index++) { | for (size_t index = 1; index < arg_length; index++) { | ||||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | MS_EXCEPTION_IF_NULL(args_spec_list[index]); | ||||
| if (args_spec_list[index]->isa<AbstractTuple>()) { | if (args_spec_list[index]->isa<AbstractTuple>()) { | ||||
| @@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt | |||||
| AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, | AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // arguments: key, value, original value | |||||
| // arguments: key, value, target type(None if no target type) | |||||
| if (args_spec_list.size() != 3) { | if (args_spec_list.size() != 3) { | ||||
| MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | ||||
| << "."; | << "."; | ||||
| } | } | ||||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | TypePtr type = args_spec_list[0]->GetTypeTrack(); | ||||
| ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); | |||||
| if (type->type_id() != kObjectTypeRefKey) { | if (type->type_id() != kObjectTypeRefKey) { | ||||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | ||||
| } | } | ||||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||||
| auto need_cast = !tensor_target_v->isa<None>(); | |||||
| if (need_cast && !tensor_target_v->isa<Type>()) { | |||||
| MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); | |||||
| } | |||||
| TypePtr cast_target = tensor_target_v->cast<TypePtr>(); | |||||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target); | |||||
| } | } | ||||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| @@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP | |||||
| } | } | ||||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | TypePtr type = args_spec_list[0]->GetTypeTrack(); | ||||
| if (type->type_id() != kObjectTypeRef) { | if (type->type_id() != kObjectTypeRef) { | ||||
| MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); | |||||
| return args_spec_list[0]; | |||||
| } | } | ||||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); | return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); | ||||
| } | } | ||||
| AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // arguments: value | |||||
| if (args_spec_list.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() | |||||
| << "."; | |||||
| } | |||||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||||
| if (type->type_id() != kObjectTypeRef) { | |||||
| MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); | |||||
| } | |||||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin(); | |||||
| } | |||||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| // args: Two objects of a subclass of AbstractBase, key and value. | // args: Two objects of a subclass of AbstractBase, key and value. | ||||
| @@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Ref eliminate | // Ref eliminate | ||||
| make_ref_eliminate_ = | make_ref_eliminate_ = | ||||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | ||||
| get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", | |||||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||||
| get_ref_param_eliminate_ = | |||||
| MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue}); | |||||
| get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | ||||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); | |||||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | ||||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | IsValueNode<RefKey>, opt::FORCE_RENORM); | ||||
| @@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller { | |||||
| }; | }; | ||||
| // {prim::kPrimGetRefValue, Parameter} -> Parameter | // {prim::kPrimGetRefValue, Parameter} -> Parameter | ||||
| // {prim::kPrimGetRefOrigin, Parameter} -> Parameter | |||||
| class GetRefParamEliminater : public OptimizerCaller { | class GetRefParamEliminater : public OptimizerCaller { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| PatternNode<AnfNodePtr> x; | PatternNode<AnfNodePtr> x; | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| }; | }; | ||||
| // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X | // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X | ||||
| // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y | // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y | ||||
| // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z | |||||
| class GetMakeRefEliminater : public OptimizerCaller { | class GetMakeRefEliminater : public OptimizerCaller { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| PatternNode<AnfNodePtr> x, y, z; | PatternNode<AnfNodePtr> x, y, z; | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); | MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); | ||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | |||||
| TypePtr dst_type; | |||||
| if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { | |||||
| return kFloat32; | |||||
| } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { | |||||
| return kFloat16; | |||||
| } else { | |||||
| return kNone; | |||||
| } | |||||
| } | |||||
| // if any mixed precision flag add a cast node after the parameter node. | // if any mixed precision flag add a cast node after the parameter node. | ||||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | ||||
| TypePtr dst_type; | TypePtr dst_type; | ||||
| @@ -359,6 +359,7 @@ class ParseAst { | |||||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | ||||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | ||||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||||
| } // namespace parse | } // namespace parse | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() { | |||||
| } | } | ||||
| namespace { | namespace { | ||||
| // if any mixed precision flag add a cast node after the parameter node. | |||||
| // argument obj should be python Parameter object | // argument obj should be python Parameter object | ||||
| // it will be converted to Parameter node here | // it will be converted to Parameter node here | ||||
| AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { | AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { | ||||
| @@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object | |||||
| } | } | ||||
| auto iter = func_graph->make_ref_params().find(para_node); | auto iter = func_graph->make_ref_params().find(para_node); | ||||
| if (iter == func_graph->make_ref_params().end()) { | if (iter == func_graph->make_ref_params().end()) { | ||||
| AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); | |||||
| ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node); | |||||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | ||||
| AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); | AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); | ||||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); | |||||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node}); | |||||
| func_graph->make_ref_params()[para_node] = ref_node; | func_graph->make_ref_params()[para_node] = ref_node; | ||||
| func_graph->add_parameter_obj_node(ref_node); | func_graph->add_parameter_obj_node(ref_node); | ||||
| return ref_node; | return ref_node; | ||||
| @@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | ||||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | ||||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | ||||
| {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, | |||||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | ||||
| {prim::kPrimDepend, {InferImplDepend, true}}, | {prim::kPrimDepend, {InferImplDepend, true}}, | ||||
| {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, | {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, | ||||
| @@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh | |||||
| free_param->debug_info()->set_name(param_name); | free_param->debug_info()->set_name(param_name); | ||||
| para_node = free_param; | para_node = free_param; | ||||
| } | } | ||||
| AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); | |||||
| ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node); | |||||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | ||||
| auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); | auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); | ||||
| AnfNodePtr ref_key_node = NewValueNode(refkey); | AnfNodePtr ref_key_node = NewValueNode(refkey); | ||||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); | |||||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node}); | |||||
| w_args.push_back(ref_node); | w_args.push_back(ref_node); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, | |||||
| TypePtr cast_target) | |||||
| : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { | |||||
| set_type(std::make_shared<RefType>()); | |||||
| auto origin_type = ref_value->BuildType(); | |||||
| if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) { | |||||
| auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element(); | |||||
| if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { | |||||
| if (cast_target != tensor_dtype) { | |||||
| need_cast_ = true; | |||||
| target_type_ = cast_target; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (ref_key && ref_key->isa<AbstractRefKey>()) { | |||||
| ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value(); | |||||
| } | |||||
| } | |||||
| BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } | |||||
| TypePtr AbstractRef::BuildType() const { | TypePtr AbstractRef::BuildType() const { | ||||
| TypePtr subtype = ref_->BuildType(); | TypePtr subtype = ref_->BuildType(); | ||||
| TypePtr subtype_origin = ref_origin_->BuildType(); | |||||
| TypePtr subtype_origin = subtype; | |||||
| if (need_cast_) { | |||||
| subtype_origin = std::make_shared<TensorType>(target_type_); | |||||
| } | |||||
| return std::make_shared<RefType>(subtype, subtype_origin); | return std::make_shared<RefType>(subtype, subtype_origin); | ||||
| } | } | ||||
| bool AbstractRef::operator==(const AbstractRef &other) const { | bool AbstractRef::operator==(const AbstractRef &other) const { | ||||
| return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); | |||||
| return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && | |||||
| (!need_cast_ || (*target_type_ == *other.target_type_)); | |||||
| // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_); | |||||
| } | } | ||||
| bool AbstractRef::operator==(const AbstractBase &other) const { | bool AbstractRef::operator==(const AbstractBase &other) const { | ||||
| @@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { | |||||
| MS_EXCEPTION_IF_NULL(other); | |||||
| if (*this == *other) { | |||||
| auto ret = shared_from_base<AbstractBase>(); | |||||
| return ret; | |||||
| } | |||||
| auto value_self = GetValueTrack(); | |||||
| MS_EXCEPTION_IF_NULL(value_self); | |||||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||||
| if (res_value == value_self) { | |||||
| auto ret = shared_from_base<AbstractBase>(); | |||||
| return ret; | |||||
| } | |||||
| auto ret = std::make_shared<AbstractRefKey>(); | |||||
| ret->set_value(res_value); | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | ||||
| auto other_ref = other->cast<AbstractRefPtr>(); | auto other_ref = other->cast<AbstractRefPtr>(); | ||||
| if (other_ref == nullptr) { | if (other_ref == nullptr) { | ||||
| auto new_ref = ref_->Join(other); | auto new_ref = ref_->Join(other); | ||||
| return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_); | |||||
| return std::make_shared<AbstractRef>(ref_key_, new_ref); | |||||
| } | } | ||||
| if (*this == *other) { | |||||
| if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { | |||||
| return shared_from_base<AbstractBase>(); | return shared_from_base<AbstractBase>(); | ||||
| } | } | ||||
| auto ref_key = ref_key_->Join(other_ref->ref_key_); | auto ref_key = ref_key_->Join(other_ref->ref_key_); | ||||
| auto ref = ref_->Join(other_ref->ref()); | auto ref = ref_->Join(other_ref->ref()); | ||||
| auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); | |||||
| return std::make_shared<AbstractRef>(ref_key, ref, ref_origin); | |||||
| return std::make_shared<AbstractRef>(ref_key, ref); | |||||
| } | } | ||||
| std::string AbstractRef::ToString() const { | std::string AbstractRef::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() | |||||
| << " origin_value: " << ref_origin_->ToString(); | |||||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); | |||||
| if (need_cast_) { | |||||
| buffer << " cast to: " << target_type_->ToString(); | |||||
| } | |||||
| auto value = GetValueTrack(); | auto value = GetValueTrack(); | ||||
| if (value) { | if (value) { | ||||
| buffer << ", value: " << value->ToString(); | buffer << ", value: " << value->ToString(); | ||||
| @@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const { | |||||
| ValuePtr AbstractNone::RealBuildValue() const { return kNone; } | ValuePtr AbstractNone::RealBuildValue() const { return kNone; } | ||||
| AbstractBasePtr AbstractRefKey::Broaden() const { | |||||
| auto refkey = std::make_shared<AbstractRefKey>(); | |||||
| refkey->set_value(kAnyValue); | |||||
| return refkey; | |||||
| } | |||||
| bool AbstractRefKey::operator==(const AbstractRefKey &other) const { | bool AbstractRefKey::operator==(const AbstractRefKey &other) const { | ||||
| ValuePtr value_self = GetValueTrack(); | ValuePtr value_self = GetValueTrack(); | ||||
| ValuePtr value_other = other.GetValueTrack(); | ValuePtr value_other = other.GetValueTrack(); | ||||
| @@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; | |||||
| class AbstractRefKey : public AbstractBase { | class AbstractRefKey : public AbstractBase { | ||||
| public: | public: | ||||
| AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } | |||||
| AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); } | |||||
| ~AbstractRefKey() override = default; | ~AbstractRefKey() override = default; | ||||
| MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) | MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) | ||||
| TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); } | TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); } | ||||
| bool operator==(const AbstractRefKey &other) const; | bool operator==(const AbstractRefKey &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); } | |||||
| AbstractBasePtr Clone() const override { | |||||
| auto cloned = std::make_shared<AbstractRefKey>(); | |||||
| cloned->set_value(GetValueTrack()); | |||||
| return cloned; | |||||
| } | |||||
| inline void set_value(const ValuePtr &value) { | |||||
| AbstractBase::set_value(value); | |||||
| ref_key_value_ = value->cast<RefKeyPtr>(); | |||||
| } | |||||
| RefKeyPtr ref_key_value() const { return ref_key_value_; } | |||||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | |||||
| AbstractBasePtr Broaden() const override; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| private: | |||||
| // cache for ref_key after build value, when value is null, return nullptr. | |||||
| RefKeyPtr ref_key_value_{nullptr}; | |||||
| }; | }; | ||||
| using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | ||||
| class AbstractRef : public AbstractBase { | class AbstractRef : public AbstractBase { | ||||
| public: | public: | ||||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) | |||||
| : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { | |||||
| set_type(std::make_shared<RefType>()); | |||||
| } | |||||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, | |||||
| TypePtr cast_target = nullptr); | |||||
| ~AbstractRef() override = default; | ~AbstractRef() override = default; | ||||
| MS_DECLARE_PARENT(AbstractRef, AbstractBase) | MS_DECLARE_PARENT(AbstractRef, AbstractBase) | ||||
| TypePtr BuildType() const override; | TypePtr BuildType() const override; | ||||
| BaseShapePtr BuildShape() const override; | |||||
| bool operator==(const AbstractRef &other) const; | bool operator==(const AbstractRef &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| AbstractBasePtr Clone() const override { | AbstractBasePtr Clone() const override { | ||||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); | |||||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); | |||||
| } | } | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| AbstractBasePtr ref() { return ref_; } | |||||
| AbstractBasePtr ref_origin() { return ref_origin_; } | |||||
| AbstractBasePtr ref_key() { return ref_key_; } | |||||
| inline AbstractBasePtr ref() const { return ref_; } | |||||
| inline AbstractBasePtr ref_key() const { return ref_key_; } | |||||
| inline RefKeyPtr ref_key_value() const { return ref_key_value_; } | |||||
| inline TypePtr target_type() const { return target_type_; } | |||||
| inline bool need_cast() const { return need_cast_; } | |||||
| AbstractBasePtr Broaden() const override { | AbstractBasePtr Broaden() const override { | ||||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); | |||||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_); | |||||
| } | } | ||||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | AbstractBasePtr Join(const AbstractBasePtr &other) override; | ||||
| std::size_t hash() const override { | std::size_t hash() const override { | ||||
| return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); | |||||
| return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||||
| } | } | ||||
| private: | private: | ||||
| AbstractBasePtr ref_key_; | AbstractBasePtr ref_key_; | ||||
| AbstractBasePtr ref_; | AbstractBasePtr ref_; | ||||
| AbstractBasePtr ref_origin_; | |||||
| // For mix presicion, only float type need to cast to float16 of float32 | |||||
| bool need_cast_; | |||||
| TypePtr target_type_; | |||||
| // cache for ref_key after build value, when value is null, return nullptr. | |||||
| RefKeyPtr ref_key_value_; | |||||
| }; | }; | ||||
| using AbstractRefPtr = std::shared_ptr<AbstractRef>; | using AbstractRefPtr = std::shared_ptr<AbstractRef>; | ||||
| @@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const { | |||||
| } | } | ||||
| if (arg->isa<AbstractRef>()) { | if (arg->isa<AbstractRef>()) { | ||||
| MS_LOG(DEBUG) << "refkey broaden"; | MS_LOG(DEBUG) << "refkey broaden"; | ||||
| auto arg_spec = dyn_cast<AbstractRef>(arg); | |||||
| auto ret_spec = arg_spec->Broaden(); | |||||
| return ret_spec; | |||||
| return arg->Broaden(); | |||||
| } | } | ||||
| return arg; | return arg; | ||||
| }); | }); | ||||
| @@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | ||||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | ||||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | ||||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | ||||
| inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | ||||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | ||||