
!4160 [refine]remove ref origin

Merge pull request !4160 from vlne-v1/remove-ref-origin
tags/v0.7.0-beta
mindspore-ci-bot committed 5 years ago
parent commit fe2c2e8330
15 changed files with 182 additions and 98 deletions
  1. mindspore/ccsrc/frontend/operator/composite/composite.cc (+34 -34)
  2. mindspore/ccsrc/frontend/operator/composite/do_signature.cc (+20 -6)
  3. mindspore/ccsrc/frontend/operator/composite/unpack_call.cc (+4 -3)
  4. mindspore/ccsrc/frontend/operator/prim_others.cc (+9 -17)
  5. mindspore/ccsrc/frontend/optimizer/irpass.cc (+3 -3)
  6. mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h (+0 -4)
  7. mindspore/ccsrc/pipeline/jit/parse/parse.cc (+11 -0)
  8. mindspore/ccsrc/pipeline/jit/parse/parse.h (+1 -0)
  9. mindspore/ccsrc/pipeline/jit/parse/resolve.cc (+4 -2)
  10. mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc (+0 -1)
  11. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+3 -2)
  12. mindspore/core/abstract/abstract_value.cc (+59 -9)
  13. mindspore/core/abstract/abstract_value.h (+33 -13)
  14. mindspore/core/abstract/analysis_context.cc (+1 -3)
  15. mindspore/core/base/core_ops.h (+0 -1)

mindspore/ccsrc/frontend/operator/composite/composite.cc (+34 -34)

@@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
 }
 
 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
-  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
-  ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
-  ptrGraph->debug_info()->set_name("hyper_map");
+  FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
+  ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
+  ptr_graph->debug_info()->set_name("hyper_map");
 
   AnfNodePtr ptrFnArg = nullptr;
   std::size_t i = 0;
   ArgsPairList argmap;
   ArgsPairList argmap2;
   if (fn_leaf_ == nullptr) {
-    ptrFnArg = ptrGraph->add_parameter();
+    ptrFnArg = ptr_graph->add_parameter();
     i = 1;
   }
 
   std::size_t size = args_spec_list.size();
   for (; i < size; ++i) {
-    argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
+    argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
   }
 
-  argmap2 = Harmonize(ptrGraph, argmap);
-  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
-  return ptrGraph;
+  argmap2 = Harmonize(ptr_graph, argmap);
+  ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
+  return ptr_graph;
 }
 
 abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
   inputs.push_back(opsTupleItem);
   inputs.push_back(cnode);
   inputs.push_back(NewValueNode(1));
-  AnfNodePtr ptrBprop = ret->NewCNode(inputs);
+  AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
 
-  doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
+  doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
   return ret;
 }
 
-void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
+void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
                               ValueNodePtr opsTupleItem) {
   MS_EXCEPTION_IF_NULL(func_graph);
 
-  AnfNodePtr ptrBPropArg = nullptr;
+  AnfNodePtr ptr_bprop_arg = nullptr;
   if (sens_param_) {
-    ptrBPropArg = func_graph->add_parameter();
+    ptr_bprop_arg = func_graph->add_parameter();
   } else {
     auto ones_like = prim::GetPythonOps("ones_like");
-    ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
+    ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
   }
 
-  AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
+  AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
 
   CNodePtr fv_bprop = nullptr;
   if (get_by_list_) {
     // python code: grads = hyper_map(F.partial(env_get, env), weights)
-    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
+    AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
     AnfNodePtr partial_env_get =
       func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
     MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
@@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
 
   CNodePtr inputs_bprop = nullptr;
   if (get_all_) {
-    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
+    inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
   }
 
   // Gradients wrt inputs and parameters
@@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
   }
 
   // Gradients wrt first input.
-  // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
-  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
+  // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
+  func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
 }
 
 // Generate the graph.
@@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
   MS_EXCEPTION_IF_NULL(real_fn);
 
-  FuncGraphPtr ptrGraph = real_fn->func_graph();
-  MS_EXCEPTION_IF_NULL(ptrGraph);
-  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
-  FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
+  FuncGraphPtr ptr_graph = real_fn->func_graph();
+  MS_EXCEPTION_IF_NULL(ptr_graph);
+  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
+  FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
   TraceManager::EndTrace();
-  auto nparam = ptrGraph->parameters().size();
+  auto nparam = ptr_graph->parameters().size();
 
   std::ostringstream ss;
   ss << "grad{" << nparam << "}";
-  dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  dfBuilder->debug_info()->set_name(ss.str());
-  ParameterPtr param_graph = dfBuilder->add_parameter();
+  df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  df_builder->debug_info()->set_name(ss.str());
+  ParameterPtr param_graph = df_builder->add_parameter();
 
   AnfNodePtr weights = nullptr;
   if (get_by_list_) {
-    weights = dfBuilder->add_parameter();
+    weights = df_builder->add_parameter();
   }
 
   std::vector<AnfNodePtr> inputs;
   inputs.push_back(NewValueNode(prim::kPrimJ));
   inputs.push_back(param_graph);
-  auto jf = dfBuilder->NewCNode(inputs);
+  auto jf = df_builder->NewCNode(inputs);
   // df is checked in GetGrad
-  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
-  auto df = GetGrad(jf, weights, ptrGraph->parameters());
+  TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
+  auto df = GetGrad(jf, weights, ptr_graph->parameters());
   TraceManager::EndTrace();
-  dfBuilder->set_output(NewValueNode(df));
+  df_builder->set_output(NewValueNode(df));
 
-  return dfBuilder;
+  return df_builder;
 }
 
 REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {


mindspore/ccsrc/frontend/operator/composite/do_signature.cc (+20 -6)

@@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
 bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
                                TypeId *arg_type = nullptr) {
   if (arg_value->isa<abstract::AbstractRef>()) {
-    if (is_write) {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
-    } else {
-      arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
+    auto ref = arg_value->cast<abstract::AbstractRefPtr>();
+    arg_value = ref->ref();
+    if (!is_write && ref->need_cast()) {
+      auto tensor_type = ref->target_type();
+      *arg_type_id = tensor_type->type_id();
+      if (arg_type != nullptr) {
+        *arg_type = kObjectTypeTensorType;
+      }
+      return true;
    }
  }
  if (arg_value->isa<abstract::AbstractTensor>()) {
@@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
     if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
       continue;
     }
+    MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
+                  << " to " << it->second;
     (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
   }
 }
@@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
 
     TypePtr type = args_spec_list[i]->GetTypeTrack();
     if (type && type->type_id() == kObjectTypeRef) {
+      auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
       if (sig == SignatureEnumRW::kRWRead) {
-        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
+        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
+        if (ref_abs && ref_abs->need_cast()) {
+          auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+          param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
+        }
       } else if (sig == SignatureEnumRW::kRWWrite) {
-        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
+        param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
         write_indices.insert(i);
       }
       // If sig is SignatureEnumRW::kRWRef, not do anything.
     } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
       MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
     }
+    MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
+                  << args_spec_list[i]->ToString();
     op_inputs.push_back(param);
   }
   // process default


mindspore/ccsrc/frontend/operator/composite/unpack_call.cc (+4 -3)

@@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
     MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
   }
 
-  (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
+  // No need to check, check will be done in infer.
   auto ret_graph = std::make_shared<FuncGraph>();
   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  ret_graph->debug_info()->set_name("UnpackCall");
 
-  AnfNodePtr fnNode = ret_graph->add_parameter();
+  AnfNodePtr fn_node = ret_graph->add_parameter();
   std::vector<AnfNodePtr> elems;
-  elems.push_back(fnNode);
+  elems.push_back(fn_node);
   for (size_t index = 1; index < arg_length; index++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[index]);
     if (args_spec_list[index]->isa<AbstractTuple>()) {


mindspore/ccsrc/frontend/operator/prim_others.cc (+9 -17)

@@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
 
 AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
                                  const AbstractBasePtrList &args_spec_list) {
-  // arguments: key, value, original value
+  // arguments: key, value, target type(None if no target type)
   if (args_spec_list.size() != 3) {
     MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
                       << ".";
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
+  ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+  auto need_cast = !tensor_target_v->isa<None>();
+  if (need_cast && !tensor_target_v->isa<Type>()) {
+    MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
+  }
+  TypePtr cast_target = tensor_target_v->cast<TypePtr>();
+  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
 }
 
 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
@@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP
   }
   TypePtr type = args_spec_list[0]->GetTypeTrack();
   if (type->type_id() != kObjectTypeRef) {
-    MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
+    return args_spec_list[0];
   }
   return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
 }
 
-AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
-                                      const AbstractBasePtrList &args_spec_list) {
-  // arguments: value
-  if (args_spec_list.size() != 1) {
-    MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
-                      << ".";
-  }
-  TypePtr type = args_spec_list[0]->GetTypeTrack();
-  if (type->type_id() != kObjectTypeRef) {
-    MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
-  }
-  return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
-}
-
 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list) {
   // args: Two objects of a subclass of AbstractBase, key and value.


mindspore/ccsrc/frontend/optimizer/irpass.cc (+3 -3)

@@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Ref eliminate
   make_ref_eliminate_ =
     MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
-  get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
-                                              {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
+  get_ref_param_eliminate_ =
+    MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
   get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
-                                             {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
+                                             {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
 
   replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
                                               IsValueNode<RefKey>, opt::FORCE_RENORM);


mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h (+0 -4)

@@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller {
 };
 
 // {prim::kPrimGetRefValue, Parameter} -> Parameter
-// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
 class GetRefParamEliminater : public OptimizerCaller {
  public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x;
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
-    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
    return nullptr;
  }
 };
 
 // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
 // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
-// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
 class GetMakeRefEliminater : public OptimizerCaller {
  public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x, y, z;
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
-    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
 
    return nullptr;
  }


mindspore/ccsrc/pipeline/jit/parse/parse.cc (+11 -0)

@@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
   return func_graph;
 }
 
+ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
+  TypePtr dst_type;
+  if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
+    return kFloat32;
+  } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
+    return kFloat16;
+  } else {
+    return kNone;
+  }
+}
+
 // if any mixed precision flag add a cast node after the parameter node.
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
   TypePtr dst_type;
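
Note: the GetMixedPrecisionTargetType helper added above only reads the graph's mixed-precision flags and returns the dtype (or kNone) that is then passed as the third input of make_ref. A minimal standalone sketch of that flag-to-target mapping, assuming hypothetical GraphFlags/Dtype/PickCastTarget names rather than the MindSpore types:

#include <iostream>
#include <optional>

// Hypothetical stand-ins for the graph flags and dtypes; not MindSpore types.
enum class Dtype { kFloat16, kFloat32 };

struct GraphFlags {
  bool mix_precision_fp16 = false;
  bool mix_precision_fp32 = false;
};

// Mirrors the selection order sketched above: the fp32 flag wins over fp16,
// and no flag means no cast target (the kNone case).
std::optional<Dtype> PickCastTarget(const GraphFlags &flags) {
  if (flags.mix_precision_fp32) return Dtype::kFloat32;
  if (flags.mix_precision_fp16) return Dtype::kFloat16;
  return std::nullopt;
}

int main() {
  GraphFlags flags;
  flags.mix_precision_fp16 = true;
  std::cout << (PickCastTarget(flags) ? "cast target set" : "no cast target") << "\n";
}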


mindspore/ccsrc/pipeline/jit/parse/parse.h (+1 -0)

@@ -359,6 +359,7 @@ class ParseAst {
 bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
 
 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
+ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
 
 }  // namespace parse
 }  // namespace mindspore


mindspore/ccsrc/pipeline/jit/parse/resolve.cc (+4 -2)

@@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() {
 }
 
 namespace {
+// if any mixed precision flag add a cast node after the parameter node.
 // argument obj should be python Parameter object
 // it will be converted to Parameter node here
 AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
@@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
   }
   auto iter = func_graph->make_ref_params().find(para_node);
   if (iter == func_graph->make_ref_params().end()) {
-    AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node);
+    ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
 
     AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
     AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
-    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node});
+    AnfNodePtr target_type_node = NewValueNode(target_type);
+    AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
     func_graph->make_ref_params()[para_node] = ref_node;
     func_graph->add_parameter_obj_node(ref_node);
     return ref_node;


mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc (+0 -1)

@@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMakeRef, {InferImplMakeRef, true}},
     {prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
     {prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
-    {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
     {prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
     {prim::kPrimDepend, {InferImplDepend, true}},
     {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},


mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+3 -2)

@@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
         free_param->debug_info()->set_name(param_name);
         para_node = free_param;
       }
-      AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node);
+      ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
       AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
       auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
       AnfNodePtr ref_key_node = NewValueNode(refkey);
-      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node});
+      AnfNodePtr target_type_node = NewValueNode(target_type);
+      AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
       w_args.push_back(ref_node);
     }
   } else {


mindspore/core/abstract/abstract_value.cc (+59 -9)

@@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const {
   return buffer.str();
 }
 
+AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
+                         TypePtr cast_target)
+    : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
+  set_type(std::make_shared<RefType>());
+  auto origin_type = ref_value->BuildType();
+  if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
+    auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
+    if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
+      if (cast_target != tensor_dtype) {
+        need_cast_ = true;
+        target_type_ = cast_target;
+      }
+    }
+  }
+  if (ref_key && ref_key->isa<AbstractRefKey>()) {
+    ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
+  }
+}
+
+BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
+
 TypePtr AbstractRef::BuildType() const {
   TypePtr subtype = ref_->BuildType();
-  TypePtr subtype_origin = ref_origin_->BuildType();
+  TypePtr subtype_origin = subtype;
+  if (need_cast_) {
+    subtype_origin = std::make_shared<TensorType>(target_type_);
+  }
   return std::make_shared<RefType>(subtype, subtype_origin);
 }
 
 bool AbstractRef::operator==(const AbstractRef &other) const {
-  return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_);
+  return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
+         (!need_cast_ || (*target_type_ == *other.target_type_));
+  // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
 }
 
 bool AbstractRef::operator==(const AbstractBase &other) const {
@@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
   return false;
 }
 
+AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
+  MS_EXCEPTION_IF_NULL(other);
+  if (*this == *other) {
+    auto ret = shared_from_base<AbstractBase>();
+    return ret;
+  }
+  auto value_self = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_self);
+  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
+  if (res_value == value_self) {
+    auto ret = shared_from_base<AbstractBase>();
+    return ret;
+  }
+  auto ret = std::make_shared<AbstractRefKey>();
+  ret->set_value(res_value);
+  return ret;
+}
+
 AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
   auto other_ref = other->cast<AbstractRefPtr>();
   if (other_ref == nullptr) {
     auto new_ref = ref_->Join(other);
-    return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_);
+    return std::make_shared<AbstractRef>(ref_key_, new_ref);
   }
-  if (*this == *other) {
+  if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
     return shared_from_base<AbstractBase>();
   }
   auto ref_key = ref_key_->Join(other_ref->ref_key_);
   auto ref = ref_->Join(other_ref->ref());
-  auto ref_origin = ref_origin_->Join(other_ref->ref_origin_);
-
-  return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
+  return std::make_shared<AbstractRef>(ref_key, ref);
 }
 
 std::string AbstractRef::ToString() const {
   std::ostringstream buffer;
   buffer << type_name() << "("
-         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
-         << " origin_value: " << ref_origin_->ToString();
+         << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
+  if (need_cast_) {
+    buffer << " cast to: " << target_type_->ToString();
+  }
   auto value = GetValueTrack();
   if (value) {
     buffer << ", value: " << value->ToString();
@@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const {
 
 ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
 
+AbstractBasePtr AbstractRefKey::Broaden() const {
+  auto refkey = std::make_shared<AbstractRefKey>();
+  refkey->set_value(kAnyValue);
+  return refkey;
+}
+
 bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
   ValuePtr value_self = GetValueTrack();
   ValuePtr value_other = other.GetValueTrack();
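
Note: the new AbstractRef constructor above only records a cast when a target type is supplied, the ref value is a tensor with a float element type, and that element type differs from the target. A minimal standalone sketch of that decision; the Dtype enum and NeedCast helper are illustrative stand-ins, not the MindSpore API:

#include <iostream>
#include <optional>

// Illustrative stand-in for the element dtypes involved; not MindSpore's TypePtr.
enum class Dtype { kFloat16, kFloat32, kInt32 };

bool IsFloat(Dtype d) { return d == Dtype::kFloat16 || d == Dtype::kFloat32; }

// element: the tensor's element dtype (nullopt when the ref value is not a tensor).
// target: the cast target carried by make_ref (nullopt when it was None).
bool NeedCast(std::optional<Dtype> element, std::optional<Dtype> target) {
  if (!target || !element) return false;  // no target type, or not a tensor value
  if (!IsFloat(*element)) return false;   // only float tensors are cast
  return *element != *target;             // already the target dtype: nothing to do
}

int main() {
  std::cout << NeedCast(Dtype::kFloat32, Dtype::kFloat16) << "\n";  // 1: fp32 -> fp16
  std::cout << NeedCast(Dtype::kInt32, Dtype::kFloat16) << "\n";    // 0: ints untouched
  std::cout << NeedCast(Dtype::kFloat16, Dtype::kFloat16) << "\n";  // 0: same dtype
}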


mindspore/core/abstract/abstract_value.h (+33 -13)

@@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
 
 class AbstractRefKey : public AbstractBase {
  public:
-  AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
+  AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
   ~AbstractRefKey() override = default;
   MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
 
   TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
   bool operator==(const AbstractRefKey &other) const;
   bool operator==(const AbstractBase &other) const override;
-  AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); }
+  AbstractBasePtr Clone() const override {
+    auto cloned = std::make_shared<AbstractRefKey>();
+    cloned->set_value(GetValueTrack());
+    return cloned;
+  }
+  inline void set_value(const ValuePtr &value) {
+    AbstractBase::set_value(value);
+    ref_key_value_ = value->cast<RefKeyPtr>();
+  }
+  RefKeyPtr ref_key_value() const { return ref_key_value_; }
+  AbstractBasePtr Join(const AbstractBasePtr &other) override;
+  AbstractBasePtr Broaden() const override;
   std::string ToString() const override;
+
+ private:
+  // cache for ref_key after build value, when value is null, return nullptr.
+  RefKeyPtr ref_key_value_{nullptr};
 };
 using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
 
 class AbstractRef : public AbstractBase {
  public:
-  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin)
-      : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) {
-    set_type(std::make_shared<RefType>());
-  }
+  AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
+              TypePtr cast_target = nullptr);
 
   ~AbstractRef() override = default;
   MS_DECLARE_PARENT(AbstractRef, AbstractBase)
 
   TypePtr BuildType() const override;
+  BaseShapePtr BuildShape() const override;
   bool operator==(const AbstractRef &other) const;
   bool operator==(const AbstractBase &other) const override;
   AbstractBasePtr Clone() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone());
+    return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
   }
   std::string ToString() const override;
-  AbstractBasePtr ref() { return ref_; }
-  AbstractBasePtr ref_origin() { return ref_origin_; }
-  AbstractBasePtr ref_key() { return ref_key_; }
+  inline AbstractBasePtr ref() const { return ref_; }
+  inline AbstractBasePtr ref_key() const { return ref_key_; }
+  inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
+  inline TypePtr target_type() const { return target_type_; }
+  inline bool need_cast() const { return need_cast_; }
   AbstractBasePtr Broaden() const override {
-    return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden());
+    return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
  }
   AbstractBasePtr Join(const AbstractBasePtr &other) override;
   std::size_t hash() const override {
-    return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);
+    return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1);  // ref_key_->hash() ^
  }
 
 private:
  AbstractBasePtr ref_key_;
  AbstractBasePtr ref_;
-  AbstractBasePtr ref_origin_;
+  // For mix presicion, only float type need to cast to float16 of float32
+  bool need_cast_;
+  TypePtr target_type_;
+  // cache for ref_key after build value, when value is null, return nullptr.
+  RefKeyPtr ref_key_value_;
 };
 using AbstractRefPtr = std::shared_ptr<AbstractRef>;




mindspore/core/abstract/analysis_context.cc (+1 -3)

@@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const {
     }
     if (arg->isa<AbstractRef>()) {
       MS_LOG(DEBUG) << "refkey broaden";
-      auto arg_spec = dyn_cast<AbstractRef>(arg);
-      auto ret_spec = arg_spec->Broaden();
-      return ret_spec;
+      return arg->Broaden();
     }
     return arg;
   });


mindspore/core/base/core_ops.h (+0 -1)

@@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
 inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
 inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
-inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
 inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
 inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
 inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");

