/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/static_analysis/prim.h"

#include <algorithm>
#include <limits>
#include <mutex>
#include <string>
#include <utility>

#include "utils/hash_set.h"
#include "frontend/operator/cc_implementations.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "frontend/operator/prim_to_function.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/pipeline.h"
#include "utils/convert_utils.h"
#include "utils/convert_utils_py.h"
#include "utils/ms_context.h"
#include "pipeline/jit/parse/data_converter.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/parallel_node_check.h"
#include "frontend/operator/ops_front_infer_function.h"

namespace mindspore {
namespace abstract {
using mindspore::parse::PyObjectWrapper;

// Primitives whose evaluation does NOT go through the AbstractEval() early-return check used by
// DoSignatureEvaluator::Run and StandardPrimEvaluator::EvalPrim.
mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{
  "MakeTuple", "make_list", "Switch", "env_setitem", "env_getitem", "Load", "UpdateState"};

namespace {
// Returns the abstract of every argument config's evaluation result, in order.
// Raises if a config or its evaluation result is null.
AbstractBasePtrList CollectArgsAbstract(const ConfigPtrList &args_conf_list) {
  AbstractBasePtrList args_spec_list;
  args_spec_list.reserve(args_conf_list.size());
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &ref) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(ref);
                         auto eval_result = ref->ObtainEvalResult();
                         MS_EXCEPTION_IF_NULL(eval_result);
                         return eval_result->abstract();
                       });
  return args_spec_list;
}
}  // namespace

// Evaluates a DoSignature primitive call.
// If the wrapped function is a primitive that is not in prims_to_skip_undetermined_infer and AbstractEval()
// yields a result, that result is returned directly. Otherwise a new CNode is generated through
// prim::GenerateCNode from the call's inputs (minus the callee), and the evaluation of out_conf is forwarded
// to it.
// Raises if out_conf's node is not a CNode or if its input count does not match args_conf_list.
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                        const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_EXCEPTION_IF_NULL(out_conf);
  AbstractBasePtrList args_spec_list = CollectArgsAbstract(args_conf_list);
  auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
  MS_EXCEPTION_IF_NULL(do_signature);
  const auto &func = do_signature->function();
  MS_EXCEPTION_IF_NULL(func);
  if (func->isa<Primitive>()) {
    auto sig_prim = func->cast<PrimitivePtr>();
    if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
      auto ret_abstract = AbstractEval(args_spec_list);
      if (ret_abstract != nullptr) {
        MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
        return ret_abstract;
      }
    }
  }

  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  }

  auto out_node = dyn_cast<CNode>(out_conf->node());
  MS_EXCEPTION_IF_NULL(out_node);
  const auto &out_node_inputs = out_node->inputs();
  if (out_node_inputs.empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
    MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
                      << args_conf_list.size() << ", inputs size " << out_node_inputs.size();
  }
  // Input 0 is the callee; the remaining inputs are the call arguments.
  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};

  // out_conf is known non-null here, so its node's scope is always the one used.
  ScopeGuard scope_guard(out_conf->node()->scope());

  AnfNodePtr new_node = nullptr;
  if (bound_node() != nullptr) {
    // Generate the node under a trace derived from the bound node's debug info.
    TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
    new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
  } else {
    new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());

  // out_node is always a CNode here; only copy its info when the generated node is a CNode as well
  // (previously a non-CNode result was cast and dereferenced unconditionally).
  if (new_node->isa<CNode>()) {
    auto new_cnode = new_node->cast<CNodePtr>();
    new_cnode->CloneCNodeInfo(out_node);
  }
  return engine->ForwardConfig(out_conf, fn_conf);
}

// Builds the argument abstracts used to specialize the graph called through UnpackGraph.
// args_spec_list[0] is the graph to unpack and is skipped. When need_unpack is true, every remaining argument
// must be a tuple (its elements are appended) or a dict (each entry is appended as an AbstractKeywordArg);
// otherwise the remaining arguments are returned as they are.
// Raises if args_spec_list is empty or an argument to unpack is neither a tuple nor a dict.
static AbstractBasePtrList GetUnpackGraphSpecArgsList(const AbstractBasePtrList &args_spec_list, bool need_unpack) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
  }
  // arg[0] is the func graph to unpack, ignore it
  AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
  if (!need_unpack) {
    return specialize_args_before_unpack;
  }

  AbstractBasePtrList graph_specialize_args;
  for (const auto &arg : specialize_args_before_unpack) {
    MS_EXCEPTION_IF_NULL(arg);
    if (arg->isa<AbstractTuple>()) {
      auto arg_tuple = arg->cast<AbstractTuplePtr>();
      const auto &tuple_elems = arg_tuple->elements();
      (void)graph_specialize_args.insert(graph_specialize_args.end(), tuple_elems.begin(), tuple_elems.end());
    } else if (arg->isa<AbstractDictionary>()) {
      auto arg_dict = arg->cast<AbstractDictionaryPtr>();
      const auto &dict_elems = arg_dict->elements();
      (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
                           [](const AbstractAttribute &item) {
                             return std::make_shared<AbstractKeywordArg>(item.first, item.second);
                           });
    } else {
      MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " << arg->ToString();
    }
  }
  return graph_specialize_args;
}

// Evaluates an UnpackGraph primitive call.
// Argument 0 must be a FuncGraphAbstractClosure; its graph is specialized (GenerateGraph) with the remaining
// (optionally unpacked) argument abstracts, excluding the trailing sens argument when with_sens_in_args() is set.
// The new graph is added to the engine's manager and out_conf's evaluation is forwarded to a value node of it.
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                        const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_EXCEPTION_IF_NULL(out_conf);
  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  }

  auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
  MS_EXCEPTION_IF_NULL(unpack_graph);
  auto out_node = out_conf->node()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(out_node);
  const auto &out_node_inputs = out_node->inputs();
  if (out_node_inputs.empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
    MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
                      << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
                      << ", inputs size " << out_node_inputs.size();
  }
  AbstractBasePtrList args_spec_list = CollectArgsAbstract(args_conf_list);

  // get the forward graph
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
  }
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
  }
  auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);

  AbstractBasePtrList graph_specialize_args =
    GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
  const bool with_sens = unpack_graph->with_sens_in_args();
  if (with_sens && graph_specialize_args.empty()) {
    MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
  }
  // When sens is passed it is the last argument; it is left out of the forward graph's specialization.
  AbstractBasePtrList graph_specialize_args_without_sens(graph_specialize_args.begin(),
                                                         graph_specialize_args.end() - (with_sens ? 1 : 0));
  auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
  MS_EXCEPTION_IF_NULL(new_graph);
  engine->func_graph_manager()->AddFuncGraph(new_graph);

  // out_conf is known non-null here, so its node's scope is always the one used.
  ScopeGuard scope_guard(out_conf->node()->scope());
  AnfNodePtr new_vnode = NewValueNode(new_graph);
  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
  return engine->ForwardConfig(out_conf, fn_conf);
}

// Recursively builds nodes that cast source_node to target_type, driven by node_type (the abstract of
// source_node):
//  - tensor whose element type is Float: emits cast(source_node, target_type) right after source_node;
//  - tuple: casts each item (fetched with TupleGetItem) and rebuilds the tuple with MakeTuple;
//  - dict: casts each value (fetched with DictGetItem) and rebuilds the dict with MakeDict;
//  - keyword arg: casts the wrapped value (ExtractKeywordArg) and rebuilds it with MakeKeywordArg;
//  - anything else: source_node is returned unchanged.
AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
                                    const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node_type);
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr target_node = source_node;
  if (node_type->isa<AbstractTensor>()) {
    auto x = node_type->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(x->element());
    if (x->element()->BuildType()->isa<Float>()) {
      auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
      MS_EXCEPTION_IF_NULL(cast);
      target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
    }
  } else if (node_type->isa<AbstractTuple>()) {
    auto x = node_type->cast<AbstractTuplePtr>();
    auto &items = x->elements();
    std::vector<AnfNodePtr> nodes;
    nodes.reserve(items.size() + 1);
    nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    int64_t idx = 0;
    for (const auto &item : items) {
      AnfNodePtr tuple_node =
        func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
      AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph);
      nodes.emplace_back(node);
      ++idx;
    }
    target_node = func_graph->NewCNode(nodes);
  } else if (node_type->isa<AbstractDictionary>()) {
    auto x = node_type->cast<AbstractDictionaryPtr>();
    auto &items = x->elements();
    std::vector<AnfNodePtr> dict_key_nodes;
    std::vector<AnfNodePtr> dict_value_nodes;
    dict_key_nodes.reserve(items.size() + 1);
    dict_value_nodes.reserve(items.size() + 1);
    dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (const auto &item : items) {
      AnfNodePtr dict_value_node =
        func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
      AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
      dict_key_nodes.emplace_back(NewValueNode(item.first));
      dict_value_nodes.emplace_back(node);
    }
    target_node =
      func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(dict_key_nodes)),
                            func_graph->NewCNode(std::move(dict_value_nodes))});
  } else if (node_type->isa<AbstractKeywordArg>()) {
    auto x = node_type->cast<AbstractKeywordArgPtr>();
    std::string kwarg_key = x->get_key();
    AnfNodePtr kwarg_value_node =
      func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
    AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
    target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
  }
  return target_node;
}

// Evaluates a MixedPrecisionCast call: input 1 is used as the cast target type and input 2 is the value to
// cast (args_spec_list[1] is its abstract). The cast nodes are built by MixedPrecisionCastHelper and the
// evaluation of out_conf is forwarded to the result.
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                               const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_EXCEPTION_IF_NULL(out_conf);
  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  }
  auto out_node = out_conf->node()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(out_node);
  const auto &out_node_inputs = out_node->inputs();
  if (out_node_inputs.empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
    MS_LOG(EXCEPTION) << "MixedPrecisionCast"
                      << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
                      << ", inputs size " << out_node_inputs.size();
  }
  // Same null-checked collection as the other evaluators (this one previously dereferenced unchecked).
  AbstractBasePtrList args_spec_list = CollectArgsAbstract(args_conf_list);

  ScopeGuard scope_guard(out_conf->node()->scope());
  TraceGuard trace_guard(std::make_shared<TraceMixedPrecision>(out_conf->node()->debug_info()));

  FuncGraphPtr func_graph = out_node->func_graph();
  constexpr size_t source_node_index = 2;
  if (out_node_inputs.size() <= source_node_index) {
    MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
  }

  AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1],
                                                 out_node_inputs[1], func_graph);
  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());

  if (new_node->isa<CNode>()) {
    auto new_cnode = new_node->cast<CNodePtr>();
    new_cnode->CloneCNodeInfo(out_node);
  }
  return engine->ForwardConfig(out_conf, fn_conf);
}

namespace {
// Converts a value to python data, mapping nullptr to None.
py::object BuildValue(const ValuePtr &value_ptr) {
  if (value_ptr == nullptr) {
    return py::none();
  }
  return ValueToPyData(value_ptr);
}

// Converts an AbstractTuple into a python dict whose shape/dtype/value entries are tuples built element-wise
// with ConvertAbstractToPython. The value entry is None when the tuple's built value is AnyValue.
// min/max value and min/max shape tuples are only emitted when at least one element provided them.
py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
  auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
  MS_EXCEPTION_IF_NULL(arg_tuple);
  size_t len = arg_tuple->size();
  py::tuple shape_tuple(len);
  py::tuple dtype_tuple(len);
  py::tuple value_tuple(len);
  py::tuple min_value_tuple(len);
  py::tuple max_value_tuple(len);
  py::tuple min_shape_tuple(len);
  py::tuple max_shape_tuple(len);
  bool dyn_shape = false;
  bool dyn_value = false;

  for (size_t i = 0; i < len; i++) {
    auto arg = arg_tuple->elements()[i];
    py::dict out = ConvertAbstractToPython(arg);
    shape_tuple[i] = out[ATTR_SHAPE];
    dtype_tuple[i] = out[ATTR_DTYPE];
    value_tuple[i] = out[ATTR_VALUE];

    // The element is a tensor shape value carrying a value range.
    if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
      min_value_tuple[i] = out[ATTR_MIN_VALUE];
      max_value_tuple[i] = out[ATTR_MAX_VALUE];
      dyn_value = true;
    }

    // The element is a tensor whose shape is dynamic.
    if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
      min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
      max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
      dyn_shape = true;
    }
  }
  auto dic = py::dict();
  dic[ATTR_SHAPE] = shape_tuple;
  dic[ATTR_DTYPE] = dtype_tuple;
  auto tuple_value = arg_tuple->BuildValue();
  MS_EXCEPTION_IF_NULL(tuple_value);
  if (tuple_value->isa<AnyValue>()) {
    dic[ATTR_VALUE] = py::none();
  } else {
    dic[ATTR_VALUE] = value_tuple;
  }

  if (dyn_value) {
    dic[ATTR_MIN_VALUE] = min_value_tuple;
    dic[ATTR_MAX_VALUE] = max_value_tuple;
  }
  if (dyn_shape) {
    dic[ATTR_MIN_SHAPE] = min_shape_tuple;
    dic[ATTR_MAX_SHAPE] = max_shape_tuple;
  }

  return dic;
}

// Converts an AbstractList into a python dict whose shape/dtype/value entries are lists built element-wise.
// The value entry is None when the list's built value is AnyValue; min/max shape lists are only emitted when
// at least one element provided them.
py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
  auto arg_list = dyn_cast<AbstractList>(abs_base);
  MS_EXCEPTION_IF_NULL(arg_list);
  size_t len = arg_list->size();
  py::list shape_list(len);
  py::list dtype_list(len);
  py::list value_list(len);
  py::list min_shape_list(len);
  py::list max_shape_list(len);
  bool dyn_shape = false;

  for (size_t i = 0; i < len; i++) {
    py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
    shape_list[i] = out[ATTR_SHAPE];
    dtype_list[i] = out[ATTR_DTYPE];
    value_list[i] = out[ATTR_VALUE];

    // The element is a tensor whose shape is dynamic.
    if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
      min_shape_list[i] = out[ATTR_MIN_SHAPE];
      max_shape_list[i] = out[ATTR_MAX_SHAPE];
      dyn_shape = true;
    }
  }

  auto dic = py::dict();
  dic[ATTR_SHAPE] = shape_list;
  dic[ATTR_DTYPE] = dtype_list;
  auto list_value = arg_list->BuildValue();
  MS_EXCEPTION_IF_NULL(list_value);
  if (list_value->isa<AnyValue>()) {
    dic[ATTR_VALUE] = py::none();
  } else {
    dic[ATTR_VALUE] = value_list;
  }

  if (dyn_shape) {
    dic[ATTR_MIN_SHAPE] = min_shape_list;
    dic[ATTR_MAX_SHAPE] = max_shape_list;
  }
  return dic;
}

// Fills *dic with the shape, dtype and value of a tensor abstract. Also fills min/max shape when both are
// non-empty, and min/max value when both are set.
void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
  auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
  MS_EXCEPTION_IF_NULL(dic);
  MS_EXCEPTION_IF_NULL(arg_tensor);
  MS_EXCEPTION_IF_NULL(arg_tensor->shape());
  (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
  const auto &min_shape = arg_tensor->shape()->min_shape();
  const auto &max_shape = arg_tensor->shape()->max_shape();
  if (!min_shape.empty() && !max_shape.empty()) {
    (*dic)[ATTR_MIN_SHAPE] = min_shape;
    (*dic)[ATTR_MAX_SHAPE] = max_shape;
  }

  auto min_value = arg_tensor->get_min_value();
  auto max_value = arg_tensor->get_max_value();
  if (min_value != nullptr && max_value != nullptr) {
    (*dic)[ATTR_MIN_VALUE] = BuildValue(min_value);
    (*dic)[ATTR_MAX_VALUE] = BuildValue(max_value);
  }

  (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
  (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
}

// Fills *dic for a function abstract: no shape, the function's type, no value. For a partial closure whose
// first bound argument is a class type object, dtype becomes TypeType and value the class's python object.
void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
  MS_EXCEPTION_IF_NULL(dic);
  MS_EXCEPTION_IF_NULL(abs_base);
  (*dic)[ATTR_SHAPE] = py::none();
  (*dic)[ATTR_DTYPE] = abs_base->BuildType();
  (*dic)[ATTR_VALUE] = py::none();
  if (abs_base->isa<PartialAbstractClosure>()) {
    AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
    if (!args.empty()) {
      MS_EXCEPTION_IF_NULL(args[0]);
      MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
      auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
      if (value != nullptr) {
        (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
        (*dic)[ATTR_VALUE] = value->obj();
      }
    }
  }
}

bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  // As x and predicate both are mindspore type statically, here we only to judge whether
  // x is predicate or is a subclass of predicate.
  return IsIdentidityOrSubclass(x, expected_type);
}

// Join all types in args_type_list;
TypePtr TypeJoin(const TypePtrList &args_type_list) {
  if (args_type_list.empty()) {
    MS_LOG(EXCEPTION) << "args_type_list is empty";
  }

  TypePtr type_tmp = args_type_list[0];
  for (std::size_t i = 1; i < args_type_list.size(); i++) {
    // Qualified call: the two-argument overload in namespace abstract, not this function.
    type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  }
  return type_tmp;
}

// Checks that every type in args_type_list is predicate or a subclass of it, and returns their join.
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  MS_EXCEPTION_IF_NULL(predicate);
  for (const auto &arg_type : args_type_list) {
    MS_EXCEPTION_IF_NULL(arg_type);
    if (!CheckType(predicate, arg_type)) {
      MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
    }
  }
  return TypeJoin(args_type_list);
}
}  // end anonymous namespace

// Converts an abstract value into a python dict with ATTR_SHAPE, ATTR_DTYPE and ATTR_VALUE entries
// (tensors may also carry min/max shape and min/max value entries).
// Raises TypeError for abstract kinds that have no python representation.
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
  MS_EXCEPTION_IF_NULL(abs_base);
  auto dic = py::dict();
  if (abs_base->isa<AbstractTensor>()) {
    ConvertAbstractTensorToPython(abs_base, &dic);
  } else if (abs_base->isa<AbstractRowTensor>()) {
    auto arg = dyn_cast<AbstractRowTensor>(abs_base);
    dic[ATTR_SHAPE] = arg->shape()->shape();
    dic[ATTR_DTYPE] = arg->BuildType();
    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  } else if (abs_base->isa<AbstractSparseTensor>()) {
    auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
    dic[ATTR_SHAPE] = arg->shape()->shape();
    dic[ATTR_DTYPE] = arg->BuildType();
    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  } else if (abs_base->isa<AbstractCSRTensor>()) {
    auto arg = dyn_cast<AbstractCSRTensor>(abs_base);
    dic[ATTR_SHAPE] = arg->shape()->shape();
    dic[ATTR_DTYPE] = arg->BuildType();
    dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
    ShapeVector shape;
    dic[ATTR_SHAPE] = shape;
    dic[ATTR_DTYPE] = abs_base->BuildType();
    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
  } else if (abs_base->isa<AbstractSlice>()) {
    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
    ShapeVector shape;
    dic[ATTR_SHAPE] = shape;
    dic[ATTR_DTYPE] = arg_slice->BuildType();
    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
  } else if (abs_base->isa<AbstractEllipsis>()) {
    dic[ATTR_SHAPE] = py::none();
    dic[ATTR_DTYPE] = py::ellipsis();
    dic[ATTR_VALUE] = py::ellipsis();
  } else if (abs_base->isa<AbstractTuple>()) {
    return AbstractTupleToPython(abs_base);
  } else if (abs_base->isa<AbstractList>()) {
    return AbstractListToPython(abs_base);
  } else if (abs_base->isa<AbstractNone>()) {
    dic[ATTR_SHAPE] = py::none();
    dic[ATTR_DTYPE] = py::none();
    dic[ATTR_VALUE] = py::none();
  } else if (abs_base->isa<AbstractFunction>()) {
    ConvertAbstractFunctionToPython(abs_base, &dic);
  } else if (abs_base->isa<AbstractUndetermined>()) {
    auto arg = dyn_cast<AbstractUndetermined>(abs_base);
    dic[ATTR_SHAPE] = py::none();
    dic[ATTR_DTYPE] = arg->BuildType();
    dic[ATTR_VALUE] = py::none();
  } else if (abs_base->isa<AbstractMonad>()) {
    dic[ATTR_SHAPE] = py::none();
    dic[ATTR_DTYPE] = abs_base->BuildType();
    dic[ATTR_VALUE] = py::none();
  } else {
    auto value = abs_base->BuildValue();
    MS_EXCEPTION_IF_NULL(value);
    if ((*value == *kAnyValue)) {
      auto value_desc = abs_base->value_desc();
      MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
                              << " for python primitive." << abs_base->ToString();
    }
    MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
                            << value->ToString();
  }
  return dic;
}

namespace {
// Converts every non-monad argument into its python dict form (see ConvertAbstractToPython).
py::tuple PreparePyInputs(const PrimitivePyPtr &, const AbstractBasePtrList &args) {
  // The monad parameter is defined at the end of the parameter and needs to be ignored
  std::size_t size_args = args.size() - GetAbstractMonadNum(args);
  py::tuple py_args(size_args);
  for (size_t i = 0; i < size_args; i++) {
    auto arg_i = args[i];
    py_args[i] = ConvertAbstractToPython(arg_i);
  }
  return py_args;
}

// For custom primitives carrying an "output_num" attribute, raises if the infer result does not match it:
// a tensor result requires output_num == 1, a tuple result requires output_num elements.
void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(res_spec);
  const std::string kOutputNum = "output_num";
  if (prim->IsCustomPrim()) {
    // Raise error if output_num is not match the infer result.
    auto output_num_value = prim->GetAttr(kOutputNum);
    if (output_num_value == nullptr) {
      MS_LOG(DEBUG) << "The output num may no need to check";
      return;
    }
    int64_t output_num = GetValue<int64_t>(output_num_value);
    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
      MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
                        << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
                        << res_spec->ToString();
    } else if (res_spec->isa<AbstractTuple>() &&
               (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
      MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
                        << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
                        << res_spec->ToString();
    }
  }
}

// Converts the dict returned by a python infer function into an abstract value.
// A None value yields an abstract built from shape/dtype; otherwise the value is converted to a constant
// abstract (tensor constants keep the converted value and value range).
AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
  // Convert to AbstractValue based on type and shape
  auto out_dtype = output[ATTR_DTYPE];
  if (output[ATTR_VALUE].is_none()) {
    auto out_shape = output[ATTR_SHAPE];
    // NOTE(review): this early return skips CheckCustomPrimOutputInferResult, so the output_num check only
    // runs for constant-valued results — confirm this is intended.
    return MakePyInferRes2Abstract(out_shape, out_dtype, output);
  }
  // Convert pyobject to Value, then to AbstractValue
  ValuePtr converted_ret = nullptr;
  TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
  bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
  if (!converted) {
    MS_LOG(EXCEPTION) << "Convert data failed";
  }
  auto res_spec = FromValue(converted_ret);
  MS_EXCEPTION_IF_NULL(res_spec);
  if (res_spec->isa<AbstractTensor>()) {
    // Replace to tensor constant node in specialize
    auto res_tensor = res_spec->cast<AbstractTensorPtr>();
    res_tensor->set_value(converted_ret);
    SetValueRange(res_tensor, output);
  }
  CheckCustomPrimOutputInferResult(prim_py, res_spec);
  return res_spec;
}
}  // end anonymous namespace

// Calls the python primitive's 'infer_value' with the argument values. A None result keeps abs_base;
// otherwise the returned object is converted (using abs_base's type) into a constant abstract.
EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
                                                     const AbstractBasePtrList &args) {
  auto prim_py = dyn_cast<PrimitivePy>(prim_);
  if (prim_py == nullptr) {
    MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
  }
  MS_EXCEPTION_IF_NULL(abs_base);
  // Call checking method 'infer_value' for python primitive
  MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  auto py_args = PreparePyInputs(prim_py, args);
  py::tuple py_vals(py_args.size());
  auto added_attrs = prim_->evaluate_added_attrs();
  for (size_t i = 0; i < py_args.size(); ++i) {
    py_vals[i] = py_args[i][ATTR_VALUE];
  }
  py::object py_ret = prim_py->RunInferValue(py_vals);
  if (py::isinstance<py::none>(py_ret)) {
    return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  }
  // Convert pyobject to Value, then to AbstractValue
  ValuePtr converted_ret = nullptr;
  TypePtr dtype = abs_base->BuildType();
  bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
  if (!converted) {
    MS_LOG(EXCEPTION) << "Convert data failed";
  }
  auto res_spec = FromValue(converted_ret);
  MS_EXCEPTION_IF_NULL(res_spec);
  if (res_spec->isa<AbstractTensor>()) {
    // Replace to tensor constant node in specialize
    auto res_tensor = res_spec->cast<AbstractTensorPtr>();
    res_tensor->set_value(converted_ret);
  }
  return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
}

// Evaluates a 'PrimitiveWithCheck' primitive: runs its python '__check__', infers the shape with the C++
// implementation, then runs 'infer_value' when the python object defines it.
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  auto prim_py = dyn_cast<PrimitivePy>(prim_);
  if (prim_py == nullptr) {
    MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
  }
  // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
  MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
  auto py_args = PreparePyInputs(prim_py, args);
  prim_py->RunCheck(py_args);

  prim_->BeginRecordAddAttr();
  AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
  prim_->EndRecordAddAttr();
  auto added_attrs = prim_->evaluate_added_attrs();

  if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
    return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  }
  // Call method 'infer_value' for primitive with this method for constant propagation
  return RunPyInferValue(engine, abs_base, args);
}

// Evaluates a standard (C++-implemented) primitive. Tries AbstractEval() first (unless the primitive is in
// prims_to_skip_undetermined_infer), dispatches PyCheck primitives to EvalPyCheckPrim, then tries the value
// inference implementation when allowed and falls back to shape inference.
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
    auto ret_abstract = AbstractEval(args);
    if (ret_abstract != nullptr) {
      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
      return ret_abstract;
    }
  }
  if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
    return EvalPyCheckPrim(engine, args);
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  bool need_infer_value = !eval_impl_.in_white_list_;
  if (!need_infer_value) {
    // White-listed primitives only infer values in graph mode and when every argument has a concrete value.
    need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
                       std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
                         MS_EXCEPTION_IF_NULL(abs);
                         auto value = abs->BuildValue();
                         return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
                                 !value->isa<Monad>() && !value->isa<FuncGraph>());
                       });
  }
  AbstractBasePtr abs_base = nullptr;
  ValuePtr value = nullptr;
  prim_->BeginRecordAddAttr();
  if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
    value = eval_impl_.infer_value_impl_(prim_, args);
    if (value != nullptr) {
      abs_base = value->ToAbstract();
      prim_->EndRecordAddAttr();
      auto added_attrs = prim_->evaluate_added_attrs();
      return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
    }
  }
  abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
  prim_->EndRecordAddAttr();
  auto added_attrs = prim_->evaluate_added_attrs();
  auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  return eval_result;
}

// Evaluates a python primitive through its python 'infer'. Results are cached per argument list; cache hits
// return a clone of the cached abstract.
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  auto ret_abstract = AbstractEval(args);
  if (ret_abstract != nullptr) {
    MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
    return ret_abstract;
  }
  MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();

  const auto eval_result = evaluator_cache_mgr_->GetValue(args);
  if (eval_result != nullptr) {
    MS_EXCEPTION_IF_NULL(eval_result->abstract());
    auto abs = eval_result->abstract()->Clone();
    auto attr = eval_result->attribute();
    return std::make_shared<EvalResult>(abs, attr);
  }

  auto py_args = PreparePyInputs(prim_py_, args);
  prim_py_->BeginRecordAddAttr();
  py::dict output = prim_py_->RunInfer(py_args);
  prim_py_->EndRecordAddAttr();
  auto added_attrs = prim_py_->evaluate_added_attrs();
  MS_LOG(DEBUG) << "Output type is " << static_cast<std::string>(py::str(output));
  auto res_spec = PyInferRes2Abstract(prim_py_, output);
  MS_EXCEPTION_IF_NULL(res_spec);
  MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  evaluator_cache_mgr_->SetValue(args, infer_result);
  return infer_result;
}

// Evaluates a uniform (scalar-only) primitive: checks argument count and types against type_map_, computes
// the value with RunImpl and returns an AbstractScalar of the resulting type.
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  auto ret_abstract = AbstractEval(args);
  if (ret_abstract != nullptr) {
    MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
    return ret_abstract;
  }
  // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  if (nargs_ != args.size()) {
    MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
  }
  TypePtr ret_value_type = return_value_type_;
  ValuePtrList value_list;
  for (const auto &arg : args) {
    // Check if all arguments are scalar type.
    MS_EXCEPTION_IF_NULL(arg);
    if (arg->isa<AbstractScalar>()) {
      auto arg_scalar = dyn_cast<AbstractScalar>(arg);
      auto arg_value = arg_scalar->GetValueTrack();
      value_list.push_back(arg_value);
    } else {
      // Raise TypeError Expected Scalar.
      MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
    }
  }
  for (const auto &item : type_map_) {
    TypePtrList selections;
    MS_EXCEPTION_IF_NULL(item.second);
    (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
                         [&args](size_t arg_idx) -> TypePtr {
                           if (arg_idx >= args.size()) {
                             MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
                           }
                           MS_EXCEPTION_IF_NULL(args[arg_idx]);
                           return args[arg_idx]->GetTypeTrack();
                         });
    TypePtr res = CheckTypeList(item.first, selections);
    MS_EXCEPTION_IF_NULL(return_value_type_);
    MS_EXCEPTION_IF_NULL(item.first);
    if (*return_value_type_ == *(item.first)) {
      ret_value_type = res;
    }
  }

  ValuePtr evaluated_value = RunImpl(value_list);
  MS_EXCEPTION_IF_NULL(evaluated_value);
  if (!(*evaluated_value == *kAnyValue)) {
    ret_value_type = evaluated_value->type();
  }
  // for comparison primitives , return type shall have be specified to be bool.
  if (specify_out_type_ != nullptr) {
    ret_value_type = specify_out_type_;
  }

  AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
  return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
}

// Returns kAnyValue when value evaluation is disabled or any argument is AnyValue; otherwise applies impl_.
ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
  if (!eval_value_) {
    return kAnyValue;
  } else {
    if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
          MS_EXCEPTION_IF_NULL(arg);
          return arg->isa<AnyValue>();
        })) {
      return kAnyValue;
    }
    return impl_(args);
  }
}

// Primitive implementation
// static function start
namespace {
EvaluatorPtr InitStandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl) {
  EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
  return prim_evaluator;
}

EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
                                      const TypePtr &specify_out_type) {
  FunctionPtr func = nullptr;
  (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
  MS_EXCEPTION_IF_NULL(func);

  EvaluatorPtr uniform_primitive_evaluator =
    std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
  return uniform_primitive_evaluator;
}

// Parses the python object wrapped by `method` into a FuncGraph and registers it with the engine's manager.
FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_EXCEPTION_IF_NULL(method);
  if (!method->isa<parse::PyObjectWrapper>()) {
    MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
  }

  std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
  FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
  }

  FuncGraphManagerPtr manager = engine->func_graph_manager();
  manager->AddFuncGraph(func_graph);
  return func_graph;
}

inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(engine);
  FuncGraphManagerPtr manager = engine->func_graph_manager();
  manager->AddFuncGraph(func_graph);
}

enum class REQUIRE_TYPE { ATTR, METHOD };

// Builds Partial(fn, data) for the function abstract of `value` (a FuncGraph or a Primitive) bound to the
// node of data_conf; for REQUIRE_TYPE::ATTR the partial is additionally called with no arguments.
// The evaluation of old_conf is forwarded to the new node.
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
                                   REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
  MS_EXCEPTION_IF_NULL(old_conf);
  AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
  AbstractFunctionPtr abs_func = dyn_cast<AbstractFunction>(abstract);
  MS_EXCEPTION_IF_NULL(abs_func);

  // Create new cnode
  std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
  auto func_graph_func = dyn_cast<FuncGraphAbstractClosure>(abs_func);
  if (func_graph_func != nullptr) {
    FuncGraphPtr fg = func_graph_func->func_graph();
    input.push_back(NewValueNode(fg));
  } else {
    auto prim_func = dyn_cast<PrimitiveAbstractClosure>(abs_func);
    MS_EXCEPTION_IF_NULL(prim_func);
    PrimitivePtr prim = prim_func->prim();
    input.push_back(NewValueNode(prim));
  }

  AnfNodeConfigPtr conf = dyn_cast<AnfNodeConfig>(data_conf);
  MS_EXCEPTION_IF_NULL(conf);
  input.push_back(conf->node());
  MS_EXCEPTION_IF_NULL(old_conf->node());
  FuncGraphPtr func_graph = old_conf->node()->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  CNodePtr new_cnode = func_graph->NewCNode(input);
  if (require_type == REQUIRE_TYPE::ATTR) {
    new_cnode = func_graph->NewCNode({new_cnode});
  }
  AnalysisEnginePtr eng = old_conf->engine();
  MS_EXCEPTION_IF_NULL(eng);
  AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
  return eng->ForwardConfig(old_conf, fn_conf);
}

// Resolves `namespace.symbol` (args_spec_list[0] must build a NameSpace, args_spec_list[1] a Symbol or a
// string) to a node, replaces out_conf's node with it in the graph's order list and forwards evaluation to it.
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
                                                  const AnfNodeConfigPtr &out_conf) {
  // args_spec_list: same as StaticGetter
  if (args_spec_list.size() < 2) {
    MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
  }
  MS_EXCEPTION_IF_NULL(out_conf);
  // An external type.
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
  auto data_value = args_spec_list[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(data_value);
  if (!data_value->isa<parse::NameSpace>()) {
    MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_value->ToString()
                            << "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
  }

  auto item_value = args_spec_list[1]->BuildValue();
  MS_EXCEPTION_IF_NULL(item_value);
  if (item_value->isa<StringImm>()) {
    item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
  }

  if (!item_value->isa<parse::Symbol>()) {
    MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
  }

  // item_name to func addr from obj_map
  parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
  parse::NameSpacePtr name_space = data_value->cast<parse::NameSpacePtr>();
  auto out_node = out_conf->node();
  MS_EXCEPTION_IF_NULL(out_node);
  FuncGraphPtr func_graph = out_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
  if (new_node == nullptr) {
    MS_LOG(EXCEPTION) << "Resolve node failed";
  }
  if (pipeline::GetJitLevel() == "o0" && IsValueNode<FuncGraph>(new_node)) {
    UpdateDebugInfo(GetValueNode<FuncGraphPtr>(new_node), out_node->scope(), out_node->debug_info());
  }

  // Replace old node with the resolved new node in order list.
  func_graph->ReplaceInOrder(out_node, new_node);

  AnalysisEnginePtr eng = out_conf->engine();
  MS_EXCEPTION_IF_NULL(eng);
  AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
  return eng->ForwardConfig(out_conf, fn_conf);
}

// Looks up item_value (a string) on the class abstract args_spec_list[0]: an attribute is returned directly,
// a method is parsed into a graph and bound to the instance via StaticGetterInferred.
// Raises AttributeError when the class has neither.
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
                                                    const AbstractBasePtrList &args_spec_list,
                                                    const ValuePtr &item_value, const ConfigPtr &data_conf,
                                                    const AnfNodeConfigPtr &out_conf) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "args_spec_list is empty";
  }
  AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
  // Checked before the first use (cls->tag() below); previously it was checked only after being dereferenced.
  MS_EXCEPTION_IF_NULL(cls);

  // If item_value is an attribute, get abstract value from AbstractClass
  MS_EXCEPTION_IF_NULL(item_value);
  if (!item_value->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << "Attribute type error";
  }
  std::string item_name = item_value->cast<StringImmPtr>()->value();
  MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
  MS_LOG(DEBUG) << "Resolve item: " << item_name;
  AbstractBasePtr attr = cls->GetAttribute(item_name);
  if (attr != nullptr) {
    return std::make_shared<EvalResult>(attr, nullptr);
  }

  ValuePtr method = cls->GetMethod(item_name);
  MS_EXCEPTION_IF_NULL(method);
  if (method->isa<AnyValue>()) {
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
    MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
                                 << ", item value: " << item_value->ToString();
  }

  // Infer class method
  ValuePtr converted_value = PyObjToGraph(engine, method);
  return StaticGetterInferred(converted_value, data_conf, out_conf);
}

EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
                                                          const TypePtr &data_type, const ConfigPtr &data_conf,
                                                          const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(item_value);
  MS_EXCEPTION_IF_NULL(data_type);
  // The method maybe a Primitive or Composite
  if (!item_value->isa()) {
    MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
} std::string item_name = item_value->cast()->value(); REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD; Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); if (require.empty()) { require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name); if (require.empty()) { MS_LOG(EXCEPTION) << "Not supported to get attribute item name:\'" << item_name << "\' of a type[" << data_type->ToString() << "]"; } require_type = REQUIRE_TYPE::ATTR; } ValuePtr converted_value = nullptr; if (require.is()) { // composite registered in standard_method_map go to this branch converted_value = prim::GetPythonOps(require.cast()); MS_EXCEPTION_IF_NULL(converted_value); if (pipeline::GetJitLevel() == "o0" && converted_value->isa()) { UpdateDebugInfo(converted_value->cast(), out_conf->node()->scope(), out_conf->node()->debug_info()); } if (!converted_value->isa()) { AddToManager(engine, converted_value->cast()); } } else if (require.is()) { converted_value = require.cast(); } else { MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString(); } return StaticGetterInferred(converted_value, data_conf, out_conf, require_type); } enum ResolveType : int64_t { kResolveTypeUserDefineClass = 1, kResolveTypeBuiltInType, kResolveTypeFunction, }; int64_t GetResolveType(const TypePtr &data_type) { MS_EXCEPTION_IF_NULL(data_type); if (data_type->type_id() == kObjectTypeClass) { return kResolveTypeUserDefineClass; } // Try to search method map, if not found, the data_type should be External type. 
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) { return kResolveTypeBuiltInType; } return kResolveTypeFunction; } EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { // Inputs: namespace and its static function; or class and its member function CheckArgsSize("StaticGetter", args_spec_list, 2); MS_EXCEPTION_IF_NULL(args_spec_list[0]); MS_EXCEPTION_IF_NULL(args_spec_list[1]); MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); TypePtr data_type = args_spec_list[0]->BuildType(); ValuePtr item_value = args_spec_list[1]->BuildValue(); ScopePtr scope = kDefaultScope; if (out_conf != nullptr) { scope = out_conf->node()->scope(); } ScopeGuard scope_guard(scope); MS_EXCEPTION_IF_NULL(item_value); if (item_value->isa()) { MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); } int64_t resolve_type = GetResolveType(data_type); if (resolve_type == kResolveTypeUserDefineClass) { return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); } else if (resolve_type == kResolveTypeBuiltInType) { return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf); } else { return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); } } } // end anonymous namespace namespace { class EmbedEvaluator : public SymbolicPrimEvaluator { public: EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} ~EmbedEvaluator() override = default; MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { // arg: free variable to be embedded if (args_conf_list.size() != 1) { MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); } 
AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); MS_EXCEPTION_IF_NULL(node_conf); MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult()); AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract(); x = SensitivityTransform(x); SymbolicKeyInstancePtr key = std::make_shared(node_conf->node(), x); AbstractScalarPtr abs_scalar = std::make_shared(key, std::make_shared()); return std::make_shared(abs_scalar, std::make_shared()); } }; static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) { MS_EXCEPTION_IF_NULL(manager); auto root_g_set = manager->roots(); if (root_g_set.size() != 1) { return nullptr; } const FuncGraphPtr &root_g = root_g_set.back(); for (auto ¶m_node : root_g->parameters()) { auto param = param_node->cast(); if (param && name == param->name()) { return param; } } return nullptr; } class RefToEmbedEvaluator : public SymbolicPrimEvaluator { public: RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} ~RefToEmbedEvaluator() override = default; MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { if (args_conf_list.size() != 1) { MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); return nullptr; } static TypePtr type = std::make_shared(); auto node_conf = dyn_cast(args_conf_list[0]); if (node_conf == nullptr) { MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; return nullptr; } MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult()); AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract(); MS_EXCEPTION_IF_NULL(abs); AbstractRefPtr ref_abs = abs->cast(); if (ref_abs == nullptr) { MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); return nullptr; } auto key_abs = ref_abs->ref_key(); if (key_abs == nullptr) { MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr."; return nullptr; } auto key_value = key_abs->BuildValue(); if (key_value 
== nullptr) { MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr."; return nullptr; } auto refkey = key_value->cast(); if (refkey == nullptr) { auto ret = std::make_shared(type); auto ref_value = ref_abs->ref(); MS_EXCEPTION_IF_NULL(ref_value); return std::make_shared(ret, std::make_shared()); } std::string name = refkey->tag(); MS_EXCEPTION_IF_NULL(node_conf->node()); if (node_conf->node()->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString(); } const auto &manager = node_conf->node()->func_graph()->manager(); auto node = FindParameterNodeByString(manager, name); if (node == nullptr) { MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph."; return nullptr; } AbstractBasePtr x = ref_abs->ref(); x = SensitivityTransform(x); std::shared_ptr key = std::make_shared(node, x); std::shared_ptr abs_scalar = std::make_shared(key, type); return std::make_shared(abs_scalar, std::make_shared()); } }; class GetAttrEvaluator : public TransitionPrimEvaluator { public: GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} ~GetAttrEvaluator() override = default; MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { constexpr auto kGetAttrArgSize = 2; auto ret_abstract = AbstractEval(args_spec_list); if (ret_abstract != nullptr) { MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; return ret_abstract; } // Inputs: data, item if (args_spec_list.size() != kGetAttrArgSize) { MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); } EvalResultPtr ret = nullptr; if (bound_node() != nullptr) { TraceGuard trace_guard(std::make_shared(bound_node()->debug_info())); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); } else { ret = 
StaticGetter(engine, args_spec_list, in_conf0, out_conf); } // don't lookup from cache, as different out_conf with same node but different context // may add different entry to anfnode_config_map, like getattr primitive; evaluator_cache_mgr_->SetValue(args_spec_list, ret); return ret; } }; class ResolveEvaluator : public TransitionPrimEvaluator { public: ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} ~ResolveEvaluator() override = default; MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { constexpr auto kResolveArgSize = 2; // Inputs: namespace, symbol if (args_spec_list.size() != kResolveArgSize) { MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); } EvalResultPtr ret = nullptr; if (bound_node() != nullptr) { TraceGuard trace_guard(std::make_shared(bound_node()->debug_info())); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); } else { ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); } return ret; } }; class CreateInstanceEvaluator : public TransitionPrimEvaluator { public: CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} ~CreateInstanceEvaluator() override = default; MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { if (args_spec_list.empty()) { MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; } // Get the type parameter. 
MS_EXCEPTION_IF_NULL(args_spec_list[0]); TypePtr type = args_spec_list[0]->GetTypeTrack(); MS_EXCEPTION_IF_NULL(type); if (type->type_id() != kMetaTypeTypeType) { MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " << type->ToString(); } ValuePtr value_track = args_spec_list[0]->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); std::shared_ptr type_obj = dyn_cast(value_track); if (type_obj == nullptr) { MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; } if (!type_obj->isa()) { MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " << type_obj->ToString() << "."; } auto class_type = type_obj->obj(); MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs). py::tuple params = GetParameters(args_spec_list); // Create class instance. auto obj = parse::data_converter::CreatePythonObject(class_type, params); if (py::isinstance(obj)) { MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type) << "` failed, only support to create \'Cell\' or \'Primitive\' object."; } // Process the object. 
ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret, true); if (!converted) { MS_LOG(EXCEPTION) << "Convert the python object failed"; } MS_EXCEPTION_IF_NULL(converted_ret); if (converted_ret->isa()) { AddToManager(engine, converted_ret->cast()); } AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); auto infer_result = std::make_shared(ret, std::make_shared()); evaluator_cache_mgr_->SetValue(args_spec_list, infer_result); return infer_result; } py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { // Exclude class type by minus 1; std::size_t params_size = args_spec_list.size() - 1; auto params = py::tuple(params_size); if (params_size > params.size()) { MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size(); } if (params_size > 0) { for (size_t i = 0; i < params_size; i++) { // Only support the Scalar parameters type. Bypass class type by offset with 1. auto arg = args_spec_list[i + 1]; MS_EXCEPTION_IF_NULL(arg); // Because the Tensor's AbstractTensor can't get value from GetValueTrack. ValuePtr param_value = arg->BuildValue(); py::object param = ValueToPyData(param_value); params[i] = param; } } return params; } }; class PyInterpretEvaluator : public TransitionPrimEvaluator { public: PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {} ~PyInterpretEvaluator() override = default; MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { if (args_spec_list.empty()) { MS_LOG(ERROR) << "'args_spec_list' should not be empty"; } // Get the type parameter. 
MS_EXCEPTION_IF_NULL(args_spec_list[0]); ValuePtr value_track = args_spec_list[0]->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); std::shared_ptr script_obj = dyn_cast(value_track); if (script_obj == nullptr) { MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; } // Make global and local parameters. py::tuple params = MakeParameters(args_spec_list); // Call python script string. MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params); auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params); if (py::isinstance(obj)) { MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`"; } ValuePtr converted_val = nullptr; bool converted = parse::ConvertData(obj, &converted_val, true); if (!converted) { MS_LOG(EXCEPTION) << "Convert the python object failed"; } MS_EXCEPTION_IF_NULL(converted_val); AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf); auto infer_result = std::make_shared(res, std::make_shared()); evaluator_cache_mgr_->SetValue(args_spec_list, infer_result); return infer_result; } py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const { constexpr int params_size = 3; if (params_size != args_spec_list.size()) { MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", not equal to arguments.size:" << args_spec_list.size(); } // The first argument is script string, ignore it. auto params = py::tuple(params_size - 1); // Make the global parameters. auto global_dict = dyn_cast(args_spec_list[1]); // Global parameters dict. 
MS_EXCEPTION_IF_NULL(global_dict); auto filtered_global_dict = FilterParameters(global_dict); MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", filtered_global_dict: " << filtered_global_dict->ToString(); ValuePtr global_dict_value = filtered_global_dict->BuildValue(); py::object global_params_dict = ValueToPyData(global_dict_value); MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << global_dict_value->ToString() << " -> " << py::str(global_params_dict); params[0] = global_params_dict; // Make the local parameters. auto local_dict = dyn_cast(args_spec_list[2]); // Local parameters dict. MS_EXCEPTION_IF_NULL(local_dict); auto filtered_local_dict = FilterParameters(local_dict); MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", filtered_local_dict:" << filtered_local_dict->ToString(); ValuePtr local_dict_value = filtered_local_dict->BuildValue(); py::object local_params_dict = ValueToPyData(local_dict_value); MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> " << py::str(local_params_dict); params[1] = local_params_dict; return params; } AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const { std::vector kv; const auto &keys_values = abstract_dict->elements(); // Filter out the element of Function type. 
(void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv), [](const AbstractAttribute &item) { MS_EXCEPTION_IF_NULL(item.second); return (!item.second->isa()); }); return std::make_shared(kv); } }; class PartialEvaluator : public Evaluator { public: PartialEvaluator() : Evaluator("PartialEvaluator") {} ~PartialEvaluator() override = default; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) override { if (args_conf_list.size() == 0) { MS_LOG(EXCEPTION) << "Args size should be greater than 0"; } MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf->node()); MS_EXCEPTION_IF_NULL(args_conf_list[0]); MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult()); auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract(); MS_EXCEPTION_IF_NULL(arg0_value); AbstractBasePtrList args_spec_list{arg0_value}; // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. if (arg0_value->isa()) { MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack()); auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() << " as func is: " << arg0_value->ToString(); auto eval_result = std::make_shared(ret, std::make_shared()); evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); return eval_result; } auto func = CheckArg("partial", args_spec_list, 0); // Sometimes, node[0] in out_conf becomes phi0; if (func->isa()) { auto prim_func = dyn_cast(func); MS_EXCEPTION_IF_NULL(prim_func->prim()); if (prim_func->prim()->isa()) { prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); return HandleDoSignature(engine, do_signature_prim->function(), out_conf); } } (void)std::transform( args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &config) -> AbstractBasePtr { return 
config->ObtainEvalResult()->abstract(); }); AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); auto cnode = out_conf->node()->cast(); MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() != (args_conf_list.size() + 1)) { MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() << ", args_conf_list: " << mindspore::ToString(args_conf_list); } AbstractFuncAtomPtrList partial_funcs_list; auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { auto new_func = std::make_shared(atom_func, args, cnode); partial_funcs_list.push_back(new_func); }; func->Visit(build_partial); auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); auto eval_result = std::make_shared(ret, std::make_shared()); evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); return eval_result; } EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, const AnfNodeConfigPtr &out_conf) const { MS_EXCEPTION_IF_NULL(engine); MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf->node()); auto cnode = out_conf->node()->cast(); if (cnode == nullptr) { MS_LOG(EXCEPTION) << "Cnode is nullptr"; } ScopeGuard scope_guard(out_conf->node()->scope()); TraceGuard trace_guard(std::make_shared(out_conf->node()->debug_info())); std::vector new_nodes_inputs = cnode->inputs(); auto new_signature_value = std::make_shared("signature", signature_value); new_nodes_inputs[1] = NewValueNode(new_signature_value); FuncGraphPtr func_graph = cnode->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); CNodePtr new_cnode = func_graph->NewCNode(std::move(new_nodes_inputs)); AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph()); return engine->ForwardConfig(out_conf, 
fn_conf); } }; struct PrimitiveImplInferValue { PrimitiveImpl impl_; // implement function of primitive bool eval_value_; // whether evaluate value TypePtr specify_out_type_; // whether specify return type bool in_white_list_; // true if this Primitive in white list, else false. }; using PrimitiveToImplMap = mindspore::HashMap; PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { using R = PrimitiveToImplMap::mapped_type; static PrimitiveToImplMap uniform_prim_implement_map{ {prim::kPrimScalarAdd, R{prim::ScalarAdd, true, nullptr, true}}, {prim::kPrimScalarSub, R{prim::ScalarSub, true, nullptr, true}}, {prim::kPrimScalarMul, R{prim::ScalarMul, true, nullptr, true}}, {prim::kPrimScalarDiv, R{prim::ScalarDiv, true, nullptr, true}}, {prim::kPrimScalarMod, R{prim::ScalarMod, true, nullptr, true}}, {prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}}, {prim::kPrimScalarFloordiv, R{prim::ScalarFloordiv, true, nullptr, true}}, {prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}}, {prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}}, {prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}}, {prim::kPrimScalarEq, R{prim::ScalarEq, true, std::make_shared(), true}}, {prim::kPrimScalarLt, R{prim::ScalarLt, true, std::make_shared(), true}}, {prim::kPrimScalarGt, R{prim::ScalarGt, true, std::make_shared(), true}}, {prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared(), true}}, {prim::kPrimScalarLe, R{prim::ScalarLe, true, std::make_shared(), true}}, {prim::kPrimScalarGe, R{prim::ScalarGe, true, std::make_shared(), true}}, {prim::kPrimBoolNot, R{prim::BoolNot, true, std::make_shared(), true}}, {prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared(), true}}, {prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared(), true}}, {prim::kPrimBoolOr, R{prim::BoolOr, true, std::make_shared(), true}}, }; return uniform_prim_implement_map; } PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); std::mutex 
PrimEvaluatorConstructorMutex; void InitPrimEvaluatorConstructors() { PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; for (const auto &iter : GetPrimitiveToEvalImplMap()) { constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second); } for (const auto &iter : GetUniformPrimitiveToImplMap()) { constructor[iter.first] = InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); } constructor[prim::kPrimEmbed] = std::make_shared(); constructor[prim::kPrimRefToEmbed] = std::make_shared(); constructor[prim::kPrimGetAttr] = std::make_shared(); constructor[prim::kPrimResolve] = std::make_shared(); constructor[prim::kPrimCreateInstance] = std::make_shared(); constructor[prim::kPrimPartial] = std::make_shared(); constructor[prim::kPrimPyInterpret] = std::make_shared(); } } // namespace void ClearPrimEvaluatorMap() { PrimEvaluatorConstructors.clear(); GetPrimitiveToEvalImplMap().clear(); GetUniformPrimitiveToImplMap().clear(); } bool IsInWhiteList(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); auto iter = GetPrimitiveToEvalImplMap().find(primitive); if (iter != GetPrimitiveToEvalImplMap().end()) { return iter->second.in_white_list_; } auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive); if (uni_iter != GetUniformPrimitiveToImplMap().end()) { return uni_iter->second.in_white_list_; } return false; } PrimEvaluatorMap &GetPrimEvaluatorConstructors() { PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; if (!constructor.empty()) { return constructor; } std::lock_guard initLock(PrimEvaluatorConstructorMutex); if (constructor.empty()) { InitPrimEvaluatorConstructors(); } return constructor; } namespace { bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); auto x_tuple = dyn_cast(x); auto model_tuple = dyn_cast(model); if (x_tuple == nullptr || model_tuple == nullptr) { return false; } if 
(model->IsGeneric()) { return true; } if (x_tuple->size() != model_tuple->size()) { return false; } for (size_t i = 0; i < x_tuple->size(); i++) { bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]); if (!is_subtype) { return false; } } return true; } bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); auto x_tensor = dyn_cast(x); auto model_tensor = dyn_cast(model); if (x_tensor == nullptr || model_tensor == nullptr) { return false; } if (model->IsGeneric()) { return true; } return IsSubtype(x_tensor->element(), model_tensor->element()); } bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); auto x_list = dyn_cast(x); auto model_list = dyn_cast(model); if (x_list == nullptr || model_list == nullptr) { return false; } if (model->IsGeneric()) { return true; } if (x_list->size() != model_list->size()) { return false; } bool is_subtype = true; for (size_t i = 0; i < x_list->size(); i++) { is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]); if (!is_subtype) { return false; } } return is_subtype; } bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); auto x_class = dyn_cast(x); auto model_class = dyn_cast(model); if (x_class == nullptr) { return false; } if (model->IsGeneric()) { return true; } MS_EXCEPTION_IF_NULL(model_class); if (x_class->tag() == model_class->tag()) { auto m_attributes = model_class->GetAttributes(); auto x_attributes = x_class->attributes(); if (m_attributes.size() != x_attributes.size()) { return false; } for (size_t i = 0; i < m_attributes.size(); i++) { if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) { return false; } } return true; } return false; } inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); if (dyn_cast(x) == nullptr) { 
return false; } TypePtr x_type = x->GetTypeTrack(); return IsSubType(x_type, model); } } // namespace bool IsSubtype(const AbstractBasePtr x, const TypePtr model) { MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(model); TypeId model_typeid = model->type_id(); switch (model_typeid) { case kMetaTypeObject: return true; case kObjectTypeTuple: return IsSubtypeTuple(x, model); case kObjectTypeTensorType: return IsSubtypeArray(x, model); case kObjectTypeList: return IsSubtypeList(x, model); case kObjectTypeClass: return IsSubtypeClass(x, model); default: if (IsSubType(model, std::make_shared())) { return IsSubtypeScalar(x, model); } MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << "."; } } } // namespace abstract } // namespace mindspore