Merge pull request !1822 from fary86/codex_big_functionstags/v0.5.0-beta
| @@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter { | |||
| void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | |||
| void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | |||
| void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx, | |||
| std::map<AnfNodePtr, int> *const apply_map); | |||
| private: | |||
| std::string GetNodeType(const AnfNodePtr &nd) override; | |||
| @@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||
| } | |||
| auto abs = ret->abstract(); | |||
| if (abs == nullptr) { | |||
| return nullptr; | |||
| return "Undefined"; | |||
| } | |||
| auto dtype = abs->BuildType(); | |||
| auto shape = abs->BuildShape(); | |||
| @@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt | |||
| return ctx; | |||
| } | |||
| void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, | |||
| int *idx, std::map<AnfNodePtr, int> *const apply_map) { | |||
| auto &inputs = cnode->inputs(); | |||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map); | |||
| // non-return node | |||
| if (cnode != func_graph->get_return()) { | |||
| int apply_idx = (*idx)++; | |||
| (*apply_map)[cnode] = apply_idx; | |||
| std::string type_info = GetNodeType(cnode); | |||
| if (type_info == "Undefined") { | |||
| ofs << " %" << apply_idx << " = " << op_text << "("; | |||
| } else { | |||
| ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; | |||
| } | |||
| } else { | |||
| ofs << " " << op_text << "("; | |||
| } | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (i != 1) { | |||
| ofs << ", "; | |||
| } | |||
| AnfNodePtr arg = inputs[i]; | |||
| ofs << GetAnfNodeText(func_graph, arg, *apply_map); | |||
| } | |||
| ofs << ")"; | |||
| // process function graph call | |||
| auto ctx = ProcessFuncGraphCall(cnode); | |||
| // output comment | |||
| OutputStatementComment(ofs, cnode); | |||
| if (ctx != nullptr) { | |||
| ofs << " @ctx.addr=" << ctx.get(); | |||
| } | |||
| ofs << "\n"; | |||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | |||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" | |||
| << label_manage::Label(cnode->debug_info()) << "\n"; | |||
| } else { | |||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; | |||
| } | |||
| } | |||
| void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | |||
| const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| @@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto &inputs = cnode->inputs(); | |||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | |||
| // non-return node | |||
| if (node != func_graph->get_return()) { | |||
| int apply_idx = idx++; | |||
| apply_map[node] = apply_idx; | |||
| std::string type_info = GetNodeType(node); | |||
| if (type_info == "Undefined") { | |||
| ofs << " %" << apply_idx << " = " << op_text << "("; | |||
| } else { | |||
| ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; | |||
| } | |||
| } else { | |||
| ofs << " " << op_text << "("; | |||
| } | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (i != 1) { | |||
| ofs << ", "; | |||
| } | |||
| AnfNodePtr arg = inputs[i]; | |||
| ofs << GetAnfNodeText(func_graph, arg, apply_map); | |||
| } | |||
| ofs << ")"; | |||
| // process function graph call | |||
| auto ctx = ProcessFuncGraphCall(cnode); | |||
| // output comment | |||
| OutputStatementComment(ofs, cnode); | |||
| if (ctx != nullptr) { | |||
| ofs << " @ctx.addr=" << ctx.get(); | |||
| } | |||
| ofs << "\n"; | |||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | |||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" | |||
| << label_manage::Label(cnode->debug_info()) << "\n"; | |||
| } else { | |||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; | |||
| } | |||
| OutputCNode(ofs, cnode, func_graph, &idx, &apply_map); | |||
| } | |||
| } | |||
| @@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num | |||
| return true; | |||
| } | |||
| void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||
| void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||
| const size_t type_number) { | |||
| *max_type_id = type_id; | |||
| *max_type = type; | |||
| *max_type_number = type_number; | |||
| } | |||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs, | |||
| const std::set<size_t> &write_indexs) { | |||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | |||
| TypeId *arg_type = nullptr) { | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| if (is_write) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||
| } else { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| } | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_type = tensor->element()->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| *arg_type_id = tensor_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| *arg_type = kObjectTypeTensorType; | |||
| } | |||
| return true; | |||
| } | |||
| if (arg_value->isa<abstract::AbstractScalar>()) { | |||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||
| auto scalar_type = scalar->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||
| *arg_type_id = scalar_type->type_id(); | |||
| if (arg_type != nullptr) { | |||
| *arg_type = kObjectTypeNumber; | |||
| } | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | |||
| const std::set<size_t> &write_indices) { | |||
| TypeId max_type_id = kTypeUnknown; | |||
| TypeId max_type = kTypeUnknown; | |||
| size_t max_type_number = 0; | |||
| bool has_int8 = false; | |||
| for (const auto &index : indexs) { | |||
| for (const auto &index : indices) { | |||
| TypeId arg_type_id = kTypeUnknown; | |||
| TypeId arg_type = kTypeUnknown; | |||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| auto is_write = (write_indexs.find(index) != write_indexs.end()); | |||
| if (is_write) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||
| } else { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| } | |||
| } | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_type = tensor->element()->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| arg_type_id = tensor_type->type_id(); | |||
| arg_type = kObjectTypeTensorType; | |||
| } else if (arg_value->isa<abstract::AbstractScalar>()) { | |||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||
| auto scalar_type = scalar->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||
| arg_type_id = scalar_type->type_id(); | |||
| arg_type = kObjectTypeNumber; | |||
| } else { | |||
| auto is_write = (write_indices.find(index) != write_indices.end()); | |||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | |||
| continue; | |||
| } | |||
| auto it = type_map.find(arg_type_id); | |||
| @@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| has_int8 = true; | |||
| } | |||
| if (max_type_id == kTypeUnknown) { | |||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| continue; | |||
| } | |||
| if (max_type == arg_type) { | |||
| if (it->second > max_type_number) { | |||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } else { | |||
| if (arg_type == kObjectTypeTensorType) { | |||
| if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { | |||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } else { | |||
| if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { | |||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||
| } | |||
| } | |||
| } | |||
| @@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | |||
| std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | |||
| const abstract::AbstractBasePtrList &args_spec_list, | |||
| const std::set<size_t> &write_indexs) { | |||
| const std::set<size_t> &write_indices) { | |||
| // record index for signature.dtypes of the same type | |||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | |||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | |||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indices; | |||
| for (size_t i = 0; i < dtypes.size(); ++i) { | |||
| auto it = type_indexs.find(dtypes[i]); | |||
| if (it == type_indexs.end()) { | |||
| (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); | |||
| auto it = type_indices.find(dtypes[i]); | |||
| if (it == type_indices.end()) { | |||
| (void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); | |||
| } else { | |||
| it->second.push_back(i); | |||
| } | |||
| } | |||
| std::map<SignatureEnumDType, TypeId> dst_type; | |||
| for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { | |||
| for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { | |||
| auto type = it->first; | |||
| auto indexs = it->second; | |||
| auto indices = it->second; | |||
| // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. | |||
| if (indexs.size() < 2) { | |||
| if (indices.size() < 2) { | |||
| continue; | |||
| } | |||
| bool has_tensor = false; | |||
| for (const auto &index : indexs) { | |||
| for (const auto &index : indices) { | |||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| @@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum | |||
| (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | |||
| continue; | |||
| } | |||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs))); | |||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); | |||
| } | |||
| return dst_type; | |||
| } | |||
| @@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap | |||
| void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | |||
| const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, | |||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) { | |||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | |||
| std::vector<SignatureEnumDType> dtypes; | |||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | |||
| [](const Signature &sig) { return sig.dtype; }); | |||
| @@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||
| return; | |||
| } | |||
| // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | |||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs); | |||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); | |||
| // Identify which arg requires auto cast | |||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | |||
| auto it = dst_type.find(dtypes[i]); | |||
| if (it == dst_type.end() || it->second == kTypeUnknown) { | |||
| continue; | |||
| } | |||
| auto rw_it = write_indexs.find(i); | |||
| auto is_write = (rw_it != write_indexs.end()); | |||
| auto rw_it = write_indices.find(i); | |||
| auto is_write = (rw_it != write_indices.end()); | |||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| if (is_write) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||
| } else { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| } | |||
| } | |||
| TypeId arg_type_id = kTypeUnknown; | |||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||
| auto tensor_type = tensor->element()->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||
| arg_type_id = tensor_type->type_id(); | |||
| } else if (arg_value->isa<abstract::AbstractScalar>()) { | |||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||
| auto scalar_type = scalar->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||
| arg_type_id = scalar_type->type_id(); | |||
| } | |||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||
| (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | |||
| auto it_map = type_map.find(arg_type_id); | |||
| if (it_map == type_map.end()) { | |||
| continue; | |||
| @@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| } | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs; | |||
| std::set<size_t> write_indexs; | |||
| std::set<size_t> write_indices; | |||
| op_inputs.push_back(NewValueNode(function)); | |||
| // Assume, the write input of op is always the first input. We check if any write op, | |||
| // and add cast op on other inputs to keep the same type with assigned parameter. | |||
| @@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); | |||
| } else if (sig == SignatureEnumRW::kRWWrite) { | |||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); | |||
| write_indexs.insert(i); | |||
| write_indices.insert(i); | |||
| } | |||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | |||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | |||
| @@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||
| } | |||
| // process default | |||
| ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); | |||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs); | |||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); | |||
| return func_graph->NewCNode(op_inputs); | |||
| } | |||
| } // namespace | |||
| @@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) { | |||
| return bprop_graph; | |||
| } | |||
| bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parse resolve function error."; | |||
| return false; | |||
| } | |||
| // if the cell object has specified bprop, it has user-defined bprop function parse and record it | |||
| if (py::hasattr(obj, "bprop")) { | |||
| FuncGraphPtr bprop_graph = nullptr; | |||
| bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug")); | |||
| if (enable_bprop_debug) { | |||
| bprop_graph = ConvertToBpropCut(obj); | |||
| } else { | |||
| bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); | |||
| } | |||
| if (bprop_graph != nullptr) { | |||
| (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); | |||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | |||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||
| } | |||
| } | |||
| *data = func_graph; | |||
| return true; | |||
| } | |||
| bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||
| auto obj_type = data_converter::GetObjType(obj); | |||
| MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; | |||
| @@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||
| // Create the namespace for common class instance | |||
| // When the obj is Cell, default parse the 'construct' | |||
| if (data_converter::IsCellInstance(obj)) { | |||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parse resolve function error."; | |||
| return false; | |||
| } | |||
| // if the cell object has specified bprop, it has user-defined bprop function parse and record it | |||
| if (py::hasattr(obj, "bprop")) { | |||
| FuncGraphPtr bprop_graph = nullptr; | |||
| bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug")); | |||
| if (enable_bprop_debug) { | |||
| bprop_graph = ConvertToBpropCut(obj); | |||
| } else { | |||
| bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); | |||
| } | |||
| if (bprop_graph != nullptr) { | |||
| (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); | |||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | |||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||
| } | |||
| } | |||
| *data = func_graph; | |||
| } else { | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | |||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||
| return ConvertCellObjToFuncGraph(obj, data); | |||
| } | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | |||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||
| return true; | |||
| } | |||
| MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); | |||
| @@ -608,7 +608,7 @@ void Pipeline::Run() { | |||
| MS_LOG(INFO) << "End"; | |||
| } | |||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) { | |||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { | |||
| std::size_t size = args.size(); | |||
| for (std::size_t i = 0; i < size; i++) { | |||
| @@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | |||
| const std::vector<int64_t> &input_indexes, bool need_run); | |||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list); | |||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); | |||
| } // namespace pipeline | |||
| } // namespace mindspore | |||
| @@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> | |||
| return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | |||
| } | |||
| void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||
| auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>(); | |||
| if (fg_eval == nullptr) { | |||
| return; | |||
| } | |||
| auto fg = fg_eval->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto undetermined_fgs = fg->recursive_graphs(); | |||
| if (undetermined_fgs) { | |||
| auto fg_parent = fg->parent(); | |||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||
| } | |||
| } | |||
| EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, | |||
| const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, | |||
| const EvalTraceRevIter &it, bool *continue_flag) { | |||
| *continue_flag = false; | |||
| // Find latest entry function to handle nested recursion. | |||
| EvaluatorPtr latest_entry = eval; | |||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||
| if (it_temp != evaluators.end()) { | |||
| latest_entry = *it_temp; | |||
| latest_entry_iter = r_it; | |||
| break; | |||
| } | |||
| latest_entry_iter = ++r_it; | |||
| } | |||
| if (latest_entry != eval) { | |||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||
| *continue_flag = true; | |||
| return latest_entry; | |||
| } | |||
| bool has_undetermined = false; | |||
| // Check whether sub loop has untraced undetermined evaluator. | |||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||
| undetermined_evals.insert(*r_it); | |||
| } | |||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||
| for (auto u_eval : undetermined_evals) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||
| has_undetermined = true; | |||
| break; | |||
| } | |||
| } | |||
| if (has_undetermined == false) { | |||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||
| *continue_flag = true; | |||
| return latest_entry; | |||
| } | |||
| return latest_entry; | |||
| } | |||
| EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { | |||
| if (out_specs.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; | |||
| } | |||
| if (out_specs.size() == 1) { | |||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | |||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | |||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||
| } | |||
| auto joined_spec = AbstractJoin(out_specs); | |||
| MS_EXCEPTION_IF_NULL(joined_spec); | |||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | |||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||
| } | |||
| EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | |||
| const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list) { | |||
| @@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||
| return conf->GetEvaluatedValue()->abstract(); | |||
| }); | |||
| for (auto eval : evaluators) { | |||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | |||
| if (fg_eval) { | |||
| auto fg = fg_eval->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto undetermined_fgs = fg->recursive_graphs(); | |||
| if (undetermined_fgs) { | |||
| auto fg_parent = fg->parent(); | |||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||
| } | |||
| } | |||
| SetUndeterminedFlag(eval); | |||
| auto current_inf = std::make_pair(eval, args_spec_list); | |||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | |||
| @@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||
| multi_poss_.clear(); | |||
| } | |||
| } else if (it != eval_trace_.rbegin()) { | |||
| // Find latest entry function to handle nested recursion. | |||
| EvaluatorPtr latest_entry = eval; | |||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||
| if (it_temp != evaluators.end()) { | |||
| latest_entry = *it_temp; | |||
| latest_entry_iter = r_it; | |||
| break; | |||
| } | |||
| latest_entry_iter = ++r_it; | |||
| } | |||
| if (latest_entry != eval) { | |||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||
| continue; | |||
| } | |||
| bool has_undetermined = false; | |||
| // Check whether sub loop has untraced undetermined evaluator. | |||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||
| undetermined_evals.insert(*r_it); | |||
| } | |||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||
| for (auto u_eval : undetermined_evals) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||
| has_undetermined = true; | |||
| break; | |||
| } | |||
| } | |||
| if (has_undetermined == false) { | |||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||
| bool continue_flag = false; | |||
| auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); | |||
| if (continue_flag) { | |||
| continue; | |||
| } | |||
| @@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||
| } | |||
| } | |||
| } | |||
| if (out_specs.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; | |||
| } | |||
| if (out_specs.size() == 1) { | |||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | |||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | |||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||
| } | |||
| auto joined_spec = AbstractJoin(out_specs); | |||
| MS_EXCEPTION_IF_NULL(joined_spec); | |||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | |||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||
| return ProcessEvalResults(out_specs); | |||
| } | |||
| EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | |||
| @@ -172,6 +172,8 @@ struct AnalysisResult { | |||
| AnalysisContextPtr context; | |||
| }; | |||
| using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator; | |||
| class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| public: | |||
| AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | |||
| @@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | |||
| private: | |||
| void SetUndeterminedFlag(const EvaluatorPtr &evaluator); | |||
| EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | |||
| const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, | |||
| bool *continue_flag); | |||
| EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); | |||
| const PrimEvaluatorMap &prim_constructors_; | |||
| FuncGraphManagerPtr func_graph_manager_; | |||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; | |||