Browse Source

!18049 code self check

Merge pull request !18049 from Margaret_wangrui/code_self_check
tags/v1.3.0
i-robot Gitee 4 years ago
parent
commit
7bf5f2b756
15 changed files with 73 additions and 21 deletions
  1. +2
    -3
      mindspore/ccsrc/frontend/operator/composite/do_signature.cc
  2. +15
    -3
      mindspore/ccsrc/pipeline/jit/parse/function_block.cc
  3. +1
    -2
      mindspore/ccsrc/pipeline/jit/parse/parse.cc
  4. +2
    -0
      mindspore/ccsrc/pipeline/jit/parse/resolve.cc
  5. +4
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc
  6. +9
    -4
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  7. +6
    -2
      mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.cc
  8. +2
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc
  9. +10
    -4
      mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
  10. +5
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
  11. +5
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  12. +3
    -1
      mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
  13. +1
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc
  14. +6
    -0
      mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
  15. +2
    -2
      mindspore/ccsrc/pybind_api/ir/signature_py.cc

+ 2
- 3
mindspore/ccsrc/frontend/operator/composite/do_signature.cc View File

@@ -150,8 +150,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, std::vector<size_t>

// Get the largest type of index in the same SignatureEnumDType of arguments.
using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types,
const std::set<size_t> &write_indices) {
MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types) {
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
@@ -207,7 +206,7 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
return;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types, write_indices);
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types);
// Identify which arg requires auto cast
for (size_t i = 0; i < input_types.size(); ++i) {
auto it = dst_type.find(dtypes[i]);


+ 15
- 3
mindspore/ccsrc/pipeline/jit/parse/function_block.cc View File

@@ -169,15 +169,19 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
return MakeResolveClassMember(bits_str);
}
py::tuple namespace_info = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
const size_t namespace_info_size = 2;
// If namespace is None, the symbol is an undefined name or an unsupported builtin function.
if (namespace_info[0].is_none()) {
// If the size of namespace_var is greater than or equal to 3, the error information is stored in namespace_var[2].
if (namespace_info.size() >= 3) {
MS_EXCEPTION(NameError) << namespace_info[2].cast<std::string>();
if (namespace_info.size() > namespace_info_size) {
MS_EXCEPTION(NameError) << namespace_info[namespace_info_size].cast<std::string>();
}
// If the size of namespace_var is less than 3, the default error information is used.
MS_EXCEPTION(NameError) << "The name \'" << value << "\' is not defined.";
}
if (namespace_info.size() < namespace_info_size) {
MS_EXCEPTION(NameError) << "namespace_info is less than 2";
}

NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_info[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[1].cast<std::string>());
@@ -186,6 +190,10 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {

AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
py::tuple namespace_var = parser_.ast()->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
const size_t namespace_var_size = 2;
if (namespace_var.size() < namespace_var_size) {
MS_EXCEPTION(NameError) << "namespace_var is less than 2";
}
NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
return MakeResolve(name_space, symbol);
@@ -216,6 +224,7 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
AnfNodePtr arg_node = pred->ReadVariable(var);
CNodePtr jump = pred->jumps_[this];
MS_EXCEPTION_IF_NULL(jump);
jump->add_input(arg_node);
}
}
@@ -225,6 +234,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
for (auto &prev : prev_blocks_) {
MS_EXCEPTION_IF_NULL(prev);
AnfNodePtr temp_node = prev->ReadVariable(var);
MS_EXCEPTION_IF_NULL(temp_node);
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var
<< " is " << temp_node->DebugString();
if (temp_node != phi) {
@@ -448,7 +458,8 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
AnfNodePtr old_output = nullptr;
auto return_node = func_graph()->get_return();
if (return_node) {
if (return_node->inputs().empty()) {
const size_t return_input_size = 2;
if (return_node->inputs().size() < return_input_size) {
MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
}
old_output = return_node->input(1);
@@ -460,6 +471,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
MS_EXCEPTION_IF_NULL(state);
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(2);
func_graph()->set_output(depend_node, true);


+ 1
- 2
mindspore/ccsrc/pipeline/jit/parse/parse.cc View File

@@ -1655,7 +1655,6 @@ void Parser::RemoveUnnecessaryPhis() {
if (removable_phis.empty()) {
return;
}
auto fg_name = func_graph_->ToString();
auto mng = Manage(func_graph_, false);
// Replace the nodes
// Remove from inside to outside
@@ -1680,7 +1679,7 @@ void Parser::RemoveUnnecessaryPhis() {
});

// Shrink container to new size
new_parameters.resize(std::distance(new_parameters.begin(), it));
new_parameters.resize(static_cast<size_t>(std::distance(new_parameters.begin(), it)));
func_graph->set_parameters(new_parameters);
}
}


+ 2
- 0
mindspore/ccsrc/pipeline/jit/parse/resolve.cc View File

@@ -203,6 +203,7 @@ bool IsAllFuncInValueSequence(const std::vector<ValuePtr> &value_vec) {
return false;
}
for (auto &elem : value_vec) {
MS_EXCEPTION_IF_NULL(elem);
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<ValuePtrList>(elem);
auto is_graph = IsAllFuncInValueSequence(vec);
@@ -221,6 +222,7 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
std::vector<AnfNodePtr> nodes;
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &elem : value_vec) {
MS_EXCEPTION_IF_NULL(elem);
AnfNodePtr node = nullptr;
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) {
const auto &vec = GetValue<std::vector<ValuePtr>>(elem);


+ 4
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -109,6 +109,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
}

FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
irpass.pynative_eliminate_,
});
@@ -146,6 +147,7 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
}

FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
irpass.reduce_eliminate_,
@@ -175,6 +177,7 @@ FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, co

FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());
if (!TransformTopGraphPass(res)) {
MS_LOG(EXCEPTION) << "Run TransformTopGraphPass failed";
}
@@ -601,6 +604,7 @@ void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
continue;
}
AbstractBasePtr par_abs = param_node->abstract();
MS_EXCEPTION_IF_NULL(par_abs);
if (par_abs->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
par_abs->BuildType()->isa<Number>())) {


+ 9
- 4
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -453,6 +453,7 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
std::string weight_name;
auto x = root_node->input(1);
MS_EXCEPTION_IF_NULL(x);
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
} else {
@@ -490,7 +491,8 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != 4) {
return;
}
auto fakequant_min_node = cnode->input(2);
const size_t fakequant_index = 2;
auto fakequant_min_node = cnode->input(fakequant_index);
if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
return;
}
@@ -525,13 +527,14 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> ExecutorPy:
IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) ||
IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel);
};
const size_t root_node_size = 3;
const size_t weight_index = 2;
for (const auto &node : nodes) {
auto root_node = node->cast<CNodePtr>();
const size_t root_node_size = 3;
if (root_node == nullptr || root_node->size() != root_node_size) {
continue;
}
auto weight = root_node->input(2);
auto weight = root_node->input(weight_index);
if (!is_quant_cnode(weight)) {
auto tuple_node = weight->cast<CNodePtr>();
if (tuple_node != nullptr) {
@@ -545,7 +548,8 @@ std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> ExecutorPy:
}
// get parameter weight's name
auto cnode = weight->cast<CNodePtr>();
auto weight_node = cnode->input(2);
MS_EXCEPTION_IF_NULL(cnode);
auto weight_node = cnode->input(weight_index);
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
continue;
}
@@ -880,6 +884,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
MS_EXCEPTION_IF_NULL(graph_params[i]);
auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
}


+ 6
- 2
mindspore/ccsrc/pipeline/jit/prim_bprop_optimizer.cc View File

@@ -75,6 +75,7 @@ void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr
void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
const std::shared_ptr<AnfNode> &param,
ParamUsingInfo *arg_info) const {
MS_EXCEPTION_IF_NULL(arg_info);
auto iter = node_users.find(param);

if (iter == node_users.end()) {
@@ -108,11 +109,13 @@ void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &nod
ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
MS_EXCEPTION_IF_NULL(arg_info);
auto cnode = user_node->cast<CNodePtr>();
if (cnode->size() != 3) {
const size_t tuple_get_item_size = 3;
const size_t index = 2;
if (cnode->size() != tuple_get_item_size) {
MS_LOG(EXCEPTION) << "TupleGetItem Node:" << user_node->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
<< "input size is:" << cnode->size();
}
auto idx_node = cnode->input(2);
auto idx_node = cnode->input(index);
if (!idx_node->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
<< " unexpected used by node:" << user_node->ToString()
@@ -140,6 +143,7 @@ void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode>
ParamUsingInfo *arg_info) const {
MS_EXCEPTION_IF_NULL(arg_info);
auto abs = param->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTensor>()) {
arg_info->tuple_flg_ = false;
MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTensor";


+ 2
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc View File

@@ -346,6 +346,7 @@ class SideEffectFinder {

static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet<CNodePtr> *new_order_list) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(new_order_list);
if (new_order_list->contains(cnode)) {
return;
}
@@ -554,6 +555,7 @@ class SideEffectFinder {
}

EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
MS_EXCEPTION_IF_NULL(tuple_indexes);
auto prim = GetPrimitive(cnode);
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {


+ 10
- 4
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc View File

@@ -297,8 +297,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list_1.begin(), joined_args_spec_list_1.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
MS_EXCEPTION_IF_NULL(arg_spec);
return arg_spec->GetShapeTrack();
});
joined_args_spec_list_1 = NormalizeArgs(joined_args_spec_list_1);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
@@ -316,8 +318,10 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list_2.begin(), joined_args_spec_list_2.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
MS_EXCEPTION_IF_NULL(arg_spec);
return arg_spec->GetShapeTrack();
});
joined_args_spec_list_2 = NormalizeArgs(joined_args_spec_list_2);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
@@ -420,6 +424,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
auto abstract = conf->ObtainEvalResult()->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// broaden the ref_key, while infer python prim for cache
if (is_py_eval && abstract->isa<AbstractRef>()) {
auto abs_ref = abstract->cast<AbstractRefPtr>();
@@ -518,6 +523,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg_spec);
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
return std::make_shared<AbstractUndetermined>();
}


+ 5
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc View File

@@ -57,6 +57,11 @@ class OrderEnforcer {
return;
}
auto update_state = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(update_state);
const size_t update_state_inputs_size = 3;
if (update_state->inputs().size() < update_state_inputs_size) {
MS_LOG(ERROR) << "UpdateState inputs size is less than 3, node is:" << update_state->DebugString();
}
if (!HasAbstractUMonad(update_state->input(1))) {
// Skip UpdateStates for IO.
return;


+ 5
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -1055,6 +1055,8 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
}

std::string name = refkey->tag();
MS_EXCEPTION_IF_NULL(node_conf->node());
MS_EXCEPTION_IF_NULL(node_conf->node()->func_graph());
const auto &manager = node_conf->node()->func_graph()->manager();
auto node = FindParameterNodeByString(manager, name);
if (node == nullptr) {
@@ -1217,6 +1219,8 @@ class PartialEvaluator : public Evaluator {

MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
MS_EXCEPTION_IF_NULL(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
@@ -1232,6 +1236,7 @@ class PartialEvaluator : public Evaluator {
// Sometimes, node[0] in out_conf becomes phi0;
if (func->isa<PrimitiveAbstractClosure>()) {
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
MS_EXCEPTION_IF_NULL(prim_func->prim());
if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
return HandleDoSignature(engine, do_signature_prim->function(), out_conf);


+ 3
- 1
mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc View File

@@ -553,10 +553,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
// First element is func so arg start from 1
std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
// CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
const size_t arg_start_index = 2;
while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
// First element is partial, second is func so arg is start from 2
(void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
(void)args.insert(args.begin(), inputs.begin() + arg_start_index, inputs.end());
func = inputs[1];
}
new_inputs = args;
@@ -738,6 +739,7 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
} else {
return nullptr;
}
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
return BuildValueNode(value, ival);


+ 1
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/stack_frame.cc View File

@@ -86,6 +86,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr

// Find parent context and create new context.
AnalysisContextPtr parent_context = GetParentContext(fg_evaluator, graph_func);
MS_EXCEPTION_IF_NULL(parent_context);
auto new_context = parent_context->NewFuncGraphContext(fg, args_abs_list);

// Evaluate the parameters with new context.


+ 6
- 0
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc View File

@@ -210,6 +210,8 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
return;
}
auto &previous_stack = list.back();
MS_EXCEPTION_IF_NULL(previous_stack->node());
MS_EXCEPTION_IF_NULL(conf->node());
auto previous_cnode_fg = previous_stack->node()->func_graph();
auto current_cnode_fg = conf->node()->func_graph();
if (previous_cnode_fg != current_cnode_fg) { // Normal.
@@ -664,6 +666,10 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
AbstractBasePtrList out_specs;
EvaluatorPtr last_eval = nullptr;
AbstractBasePtr last_abstract = nullptr;
const size_t evaluators_size = 2;
if (evaluators.size() < evaluators_size) {
MS_LOG(ERROR) << "evaluators size is less than 2";
}
multi_poss_[evaluators[0]] = evaluators[1];
multi_poss_[evaluators[1]] = evaluators[0];
AbstractBasePtrList args_spec_list;


+ 2
- 2
mindspore/ccsrc/pybind_api/ir/signature_py.cc View File

@@ -32,8 +32,8 @@ static ValuePtr PyArgToValue(const py::object &arg) {
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::class_<Signature>(*m, "Signature")
.def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
py::object arg_default, SignatureEnumDType dtype) {
.def(py::init([](const std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
const py::object arg_default, SignatureEnumDType dtype) {
auto default_value = PyArgToValue(arg_default);
return Signature(name, rw, kind, default_value, dtype);
}));


Loading…
Cancel
Save