Merge pull request !4160 from vlne-v1/remove-ref-origintags/v0.7.0-beta
| @@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL | |||
| } | |||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("hyper_map"); | |||
| FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>(); | |||
| ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptr_graph->debug_info()->set_name("hyper_map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| std::size_t i = 0; | |||
| ArgsPairList argmap; | |||
| ArgsPairList argmap2; | |||
| if (fn_leaf_ == nullptr) { | |||
| ptrFnArg = ptrGraph->add_parameter(); | |||
| ptrFnArg = ptr_graph->add_parameter(); | |||
| i = 1; | |||
| } | |||
| std::size_t size = args_spec_list.size(); | |||
| for (; i < size; ++i) { | |||
| argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); | |||
| argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i])); | |||
| } | |||
| argmap2 = Harmonize(ptrGraph, argmap); | |||
| ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); | |||
| return ptrGraph; | |||
| argmap2 = Harmonize(ptr_graph, argmap); | |||
| ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2)); | |||
| return ptr_graph; | |||
| } | |||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| @@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||
| inputs.push_back(opsTupleItem); | |||
| inputs.push_back(cnode); | |||
| inputs.push_back(NewValueNode(1)); | |||
| AnfNodePtr ptrBprop = ret->NewCNode(inputs); | |||
| AnfNodePtr ptr_bprop = ret->NewCNode(inputs); | |||
| doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); | |||
| doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem); | |||
| return ret; | |||
| } | |||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights, | |||
| ValueNodePtr opsTupleItem) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| AnfNodePtr ptrBPropArg = nullptr; | |||
| AnfNodePtr ptr_bprop_arg = nullptr; | |||
| if (sens_param_) { | |||
| ptrBPropArg = func_graph->add_parameter(); | |||
| ptr_bprop_arg = func_graph->add_parameter(); | |||
| } else { | |||
| auto ones_like = prim::GetPythonOps("ones_like"); | |||
| ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); | |||
| ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out}); | |||
| } | |||
| AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); | |||
| AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg}); | |||
| CNodePtr fv_bprop = nullptr; | |||
| if (get_by_list_) { | |||
| // python code: grads = hyper_map(F.partial(env_get, env), weights) | |||
| AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); | |||
| AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)}); | |||
| AnfNodePtr partial_env_get = | |||
| func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); | |||
| MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>(); | |||
| @@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An | |||
| CNodePtr inputs_bprop = nullptr; | |||
| if (get_all_) { | |||
| inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); | |||
| inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp}); | |||
| } | |||
| // Gradients wrt inputs and parameters | |||
| @@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An | |||
| } | |||
| // Gradients wrt first input. | |||
| // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input | |||
| func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); | |||
| // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input | |||
| func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)})); | |||
| } | |||
| // Generate the graph. | |||
| @@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp | |||
| auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn); | |||
| MS_EXCEPTION_IF_NULL(real_fn); | |||
| FuncGraphPtr ptrGraph = real_fn->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(ptrGraph); | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); | |||
| FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>(); | |||
| FuncGraphPtr ptr_graph = real_fn->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(ptr_graph); | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); | |||
| FuncGraphPtr df_builder = std::make_shared<FuncGraph>(); | |||
| TraceManager::EndTrace(); | |||
| auto nparam = ptrGraph->parameters().size(); | |||
| auto nparam = ptr_graph->parameters().size(); | |||
| std::ostringstream ss; | |||
| ss << "grad{" << nparam << "}"; | |||
| dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| dfBuilder->debug_info()->set_name(ss.str()); | |||
| ParameterPtr param_graph = dfBuilder->add_parameter(); | |||
| df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| df_builder->debug_info()->set_name(ss.str()); | |||
| ParameterPtr param_graph = df_builder->add_parameter(); | |||
| AnfNodePtr weights = nullptr; | |||
| if (get_by_list_) { | |||
| weights = dfBuilder->add_parameter(); | |||
| weights = df_builder->add_parameter(); | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimJ)); | |||
| inputs.push_back(param_graph); | |||
| auto jf = dfBuilder->NewCNode(inputs); | |||
| auto jf = df_builder->NewCNode(inputs); | |||
| // df is checked in GetGrad | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); | |||
| auto df = GetGrad(jf, weights, ptrGraph->parameters()); | |||
| TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); | |||
| auto df = GetGrad(jf, weights, ptr_graph->parameters()); | |||
| TraceManager::EndTrace(); | |||
| dfBuilder->set_output(NewValueNode(df)); | |||
| df_builder->set_output(NewValueNode(df)); | |||
| return dfBuilder; | |||
| return df_builder; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | |||
| @@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ | |||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | |||
| TypeId *arg_type = nullptr) { | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| if (is_write) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||
| } else { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| auto ref = arg_value->cast<abstract::AbstractRefPtr>(); | |||
| arg_value = ref->ref(); | |||
| if (!is_write && ref->need_cast()) { | |||
| auto tensor_type = ref->target_type(); | |||
| *arg_type_id = tensor_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| *arg_type = kObjectTypeTensorType; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| @@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { | |||
| continue; | |||
| } | |||
| MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id | |||
| << " to " << it->second; | |||
| (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); | |||
| } | |||
| } | |||
| @@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| TypePtr type = args_spec_list[i]->GetTypeTrack(); | |||
| if (type && type->type_id() == kObjectTypeRef) { | |||
| auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>(); | |||
| if (sig == SignatureEnumRW::kRWRead) { | |||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); | |||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||
| if (ref_abs && ref_abs->need_cast()) { | |||
| auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); | |||
| param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); | |||
| } | |||
| } else if (sig == SignatureEnumRW::kRWWrite) { | |||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); | |||
| param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); | |||
| write_indices.insert(i); | |||
| } | |||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | |||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | |||
| MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; | |||
| } | |||
| MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " | |||
| << args_spec_list[i]->ToString(); | |||
| op_inputs.push_back(param); | |||
| } | |||
| // process default | |||
| @@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||
| MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; | |||
| } | |||
| (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); | |||
| // No need to check, check will be done in infer. | |||
| auto ret_graph = std::make_shared<FuncGraph>(); | |||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||
| ret_graph->debug_info()->set_name("UnpackCall"); | |||
| AnfNodePtr fnNode = ret_graph->add_parameter(); | |||
| AnfNodePtr fn_node = ret_graph->add_parameter(); | |||
| std::vector<AnfNodePtr> elems; | |||
| elems.push_back(fnNode); | |||
| elems.push_back(fn_node); | |||
| for (size_t index = 1; index < arg_length; index++) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[index]); | |||
| if (args_spec_list[index]->isa<AbstractTuple>()) { | |||
| @@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt | |||
| AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // arguments: key, value, original value | |||
| // arguments: key, value, target type(None if no target type) | |||
| if (args_spec_list.size() != 3) { | |||
| MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); | |||
| if (type->type_id() != kObjectTypeRefKey) { | |||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | |||
| } | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| auto need_cast = !tensor_target_v->isa<None>(); | |||
| if (need_cast && !tensor_target_v->isa<Type>()) { | |||
| MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); | |||
| } | |||
| TypePtr cast_target = tensor_target_v->cast<TypePtr>(); | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target); | |||
| } | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| @@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP | |||
| } | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| if (type->type_id() != kObjectTypeRef) { | |||
| MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); | |||
| return args_spec_list[0]; | |||
| } | |||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); | |||
| } | |||
| AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // arguments: value | |||
| if (args_spec_list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| TypePtr type = args_spec_list[0]->GetTypeTrack(); | |||
| if (type->type_id() != kObjectTypeRef) { | |||
| MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); | |||
| } | |||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin(); | |||
| } | |||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // args: Two objects of a subclass of AbstractBase, key and value. | |||
| @@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Ref eliminate | |||
| make_ref_eliminate_ = | |||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", | |||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| get_ref_param_eliminate_ = | |||
| MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue}); | |||
| get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | |||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); | |||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | |||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| @@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller { | |||
| }; | |||
| // {prim::kPrimGetRefValue, Parameter} -> Parameter | |||
| // {prim::kPrimGetRefOrigin, Parameter} -> Parameter | |||
| class GetRefParamEliminater : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| PatternNode<AnfNodePtr> x; | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x); | |||
| return nullptr; | |||
| } | |||
| }; | |||
| // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X | |||
| // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y | |||
| // {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z | |||
| class GetMakeRefEliminater : public OptimizerCaller { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| PatternNode<AnfNodePtr> x, y, z; | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); | |||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); | |||
| return nullptr; | |||
| } | |||
| @@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo | |||
| return func_graph; | |||
| } | |||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | |||
| TypePtr dst_type; | |||
| if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { | |||
| return kFloat32; | |||
| } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { | |||
| return kFloat16; | |||
| } else { | |||
| return kNone; | |||
| } | |||
| } | |||
| // if any mixed precision flag add a cast node after the parameter node. | |||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { | |||
| TypePtr dst_type; | |||
| @@ -359,6 +359,7 @@ class ParseAst { | |||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | |||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||
| ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | |||
| } // namespace parse | |||
| } // namespace mindspore | |||
| @@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() { | |||
| } | |||
| namespace { | |||
| // if any mixed precision flag add a cast node after the parameter node. | |||
| // argument obj should be python Parameter object | |||
| // it will be converted to Parameter node here | |||
| AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { | |||
| @@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object | |||
| } | |||
| auto iter = func_graph->make_ref_params().find(para_node); | |||
| if (iter == func_graph->make_ref_params().end()) { | |||
| AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); | |||
| ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node); | |||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||
| AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); | |||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); | |||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||
| AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node}); | |||
| func_graph->make_ref_params()[para_node] = ref_node; | |||
| func_graph->add_parameter_obj_node(ref_node); | |||
| return ref_node; | |||
| @@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimMakeRef, {InferImplMakeRef, true}}, | |||
| {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, | |||
| {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, | |||
| {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, | |||
| {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, | |||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||
| {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, | |||
| @@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh | |||
| free_param->debug_info()->set_name(param_name); | |||
| para_node = free_param; | |||
| } | |||
| AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); | |||
| ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node); | |||
| AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); | |||
| auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); | |||
| AnfNodePtr ref_key_node = NewValueNode(refkey); | |||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); | |||
| AnfNodePtr target_type_node = NewValueNode(target_type); | |||
| AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node}); | |||
| w_args.push_back(ref_node); | |||
| } | |||
| } else { | |||
| @@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, | |||
| TypePtr cast_target) | |||
| : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { | |||
| set_type(std::make_shared<RefType>()); | |||
| auto origin_type = ref_value->BuildType(); | |||
| if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) { | |||
| auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element(); | |||
| if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { | |||
| if (cast_target != tensor_dtype) { | |||
| need_cast_ = true; | |||
| target_type_ = cast_target; | |||
| } | |||
| } | |||
| } | |||
| if (ref_key && ref_key->isa<AbstractRefKey>()) { | |||
| ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value(); | |||
| } | |||
| } | |||
| BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } | |||
| TypePtr AbstractRef::BuildType() const { | |||
| TypePtr subtype = ref_->BuildType(); | |||
| TypePtr subtype_origin = ref_origin_->BuildType(); | |||
| TypePtr subtype_origin = subtype; | |||
| if (need_cast_) { | |||
| subtype_origin = std::make_shared<TensorType>(target_type_); | |||
| } | |||
| return std::make_shared<RefType>(subtype, subtype_origin); | |||
| } | |||
| bool AbstractRef::operator==(const AbstractRef &other) const { | |||
| return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); | |||
| return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && | |||
| (!need_cast_ || (*target_type_ == *other.target_type_)); | |||
| // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_); | |||
| } | |||
| bool AbstractRef::operator==(const AbstractBase &other) const { | |||
| @@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const { | |||
| return false; | |||
| } | |||
| AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { | |||
| MS_EXCEPTION_IF_NULL(other); | |||
| if (*this == *other) { | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| return ret; | |||
| } | |||
| auto value_self = GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_self); | |||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||
| if (res_value == value_self) { | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| return ret; | |||
| } | |||
| auto ret = std::make_shared<AbstractRefKey>(); | |||
| ret->set_value(res_value); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { | |||
| auto other_ref = other->cast<AbstractRefPtr>(); | |||
| if (other_ref == nullptr) { | |||
| auto new_ref = ref_->Join(other); | |||
| return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_); | |||
| return std::make_shared<AbstractRef>(ref_key_, new_ref); | |||
| } | |||
| if (*this == *other) { | |||
| if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { | |||
| return shared_from_base<AbstractBase>(); | |||
| } | |||
| auto ref_key = ref_key_->Join(other_ref->ref_key_); | |||
| auto ref = ref_->Join(other_ref->ref()); | |||
| auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); | |||
| return std::make_shared<AbstractRef>(ref_key, ref, ref_origin); | |||
| return std::make_shared<AbstractRef>(ref_key, ref); | |||
| } | |||
| std::string AbstractRef::ToString() const { | |||
| std::ostringstream buffer; | |||
| buffer << type_name() << "(" | |||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() | |||
| << " origin_value: " << ref_origin_->ToString(); | |||
| << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); | |||
| if (need_cast_) { | |||
| buffer << " cast to: " << target_type_->ToString(); | |||
| } | |||
| auto value = GetValueTrack(); | |||
| if (value) { | |||
| buffer << ", value: " << value->ToString(); | |||
| @@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const { | |||
| ValuePtr AbstractNone::RealBuildValue() const { return kNone; } | |||
| AbstractBasePtr AbstractRefKey::Broaden() const { | |||
| auto refkey = std::make_shared<AbstractRefKey>(); | |||
| refkey->set_value(kAnyValue); | |||
| return refkey; | |||
| } | |||
| bool AbstractRefKey::operator==(const AbstractRefKey &other) const { | |||
| ValuePtr value_self = GetValueTrack(); | |||
| ValuePtr value_other = other.GetValueTrack(); | |||
| @@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; | |||
| class AbstractRefKey : public AbstractBase { | |||
| public: | |||
| AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } | |||
| AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); } | |||
| ~AbstractRefKey() override = default; | |||
| MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) | |||
| TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); } | |||
| bool operator==(const AbstractRefKey &other) const; | |||
| bool operator==(const AbstractBase &other) const override; | |||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); } | |||
| AbstractBasePtr Clone() const override { | |||
| auto cloned = std::make_shared<AbstractRefKey>(); | |||
| cloned->set_value(GetValueTrack()); | |||
| return cloned; | |||
| } | |||
| inline void set_value(const ValuePtr &value) { | |||
| AbstractBase::set_value(value); | |||
| ref_key_value_ = value->cast<RefKeyPtr>(); | |||
| } | |||
| RefKeyPtr ref_key_value() const { return ref_key_value_; } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | |||
| AbstractBasePtr Broaden() const override; | |||
| std::string ToString() const override; | |||
| private: | |||
| // cache for ref_key after build value, when value is null, return nullptr. | |||
| RefKeyPtr ref_key_value_{nullptr}; | |||
| }; | |||
| using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; | |||
| class AbstractRef : public AbstractBase { | |||
| public: | |||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) | |||
| : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { | |||
| set_type(std::make_shared<RefType>()); | |||
| } | |||
| AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, | |||
| TypePtr cast_target = nullptr); | |||
| ~AbstractRef() override = default; | |||
| MS_DECLARE_PARENT(AbstractRef, AbstractBase) | |||
| TypePtr BuildType() const override; | |||
| BaseShapePtr BuildShape() const override; | |||
| bool operator==(const AbstractRef &other) const; | |||
| bool operator==(const AbstractBase &other) const override; | |||
| AbstractBasePtr Clone() const override { | |||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); | |||
| return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); | |||
| } | |||
| std::string ToString() const override; | |||
| AbstractBasePtr ref() { return ref_; } | |||
| AbstractBasePtr ref_origin() { return ref_origin_; } | |||
| AbstractBasePtr ref_key() { return ref_key_; } | |||
| inline AbstractBasePtr ref() const { return ref_; } | |||
| inline AbstractBasePtr ref_key() const { return ref_key_; } | |||
| inline RefKeyPtr ref_key_value() const { return ref_key_value_; } | |||
| inline TypePtr target_type() const { return target_type_; } | |||
| inline bool need_cast() const { return need_cast_; } | |||
| AbstractBasePtr Broaden() const override { | |||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); | |||
| return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_); | |||
| } | |||
| AbstractBasePtr Join(const AbstractBasePtr &other) override; | |||
| std::size_t hash() const override { | |||
| return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); | |||
| return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^ | |||
| } | |||
| private: | |||
| AbstractBasePtr ref_key_; | |||
| AbstractBasePtr ref_; | |||
| AbstractBasePtr ref_origin_; | |||
| // For mix presicion, only float type need to cast to float16 of float32 | |||
| bool need_cast_; | |||
| TypePtr target_type_; | |||
| // cache for ref_key after build value, when value is null, return nullptr. | |||
| RefKeyPtr ref_key_value_; | |||
| }; | |||
| using AbstractRefPtr = std::shared_ptr<AbstractRef>; | |||
| @@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const { | |||
| } | |||
| if (arg->isa<AbstractRef>()) { | |||
| MS_LOG(DEBUG) << "refkey broaden"; | |||
| auto arg_spec = dyn_cast<AbstractRef>(arg); | |||
| auto ret_spec = arg_spec->Broaden(); | |||
| return ret_spec; | |||
| return arg->Broaden(); | |||
| } | |||
| return arg; | |||
| }); | |||
| @@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); | |||
| inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); | |||
| inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); | |||
| inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); | |||
| inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin"); | |||
| inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); | |||
| inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||