Merge pull request !1822 from fary86/codex_big_functionstags/v0.5.0-beta
| @@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter { | |||||
| void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | ||||
| void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | ||||
| void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx, | |||||
| std::map<AnfNodePtr, int> *const apply_map); | |||||
| private: | private: | ||||
| std::string GetNodeType(const AnfNodePtr &nd) override; | std::string GetNodeType(const AnfNodePtr &nd) override; | ||||
| @@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||||
| } | } | ||||
| auto abs = ret->abstract(); | auto abs = ret->abstract(); | ||||
| if (abs == nullptr) { | if (abs == nullptr) { | ||||
| return nullptr; | |||||
| return "Undefined"; | |||||
| } | } | ||||
| auto dtype = abs->BuildType(); | auto dtype = abs->BuildType(); | ||||
| auto shape = abs->BuildShape(); | auto shape = abs->BuildShape(); | ||||
| @@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt | |||||
| return ctx; | return ctx; | ||||
| } | } | ||||
| void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, | |||||
| int *idx, std::map<AnfNodePtr, int> *const apply_map) { | |||||
| auto &inputs = cnode->inputs(); | |||||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map); | |||||
| // non-return node | |||||
| if (cnode != func_graph->get_return()) { | |||||
| int apply_idx = (*idx)++; | |||||
| (*apply_map)[cnode] = apply_idx; | |||||
| std::string type_info = GetNodeType(cnode); | |||||
| if (type_info == "Undefined") { | |||||
| ofs << " %" << apply_idx << " = " << op_text << "("; | |||||
| } else { | |||||
| ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; | |||||
| } | |||||
| } else { | |||||
| ofs << " " << op_text << "("; | |||||
| } | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if (i != 1) { | |||||
| ofs << ", "; | |||||
| } | |||||
| AnfNodePtr arg = inputs[i]; | |||||
| ofs << GetAnfNodeText(func_graph, arg, *apply_map); | |||||
| } | |||||
| ofs << ")"; | |||||
| // process function graph call | |||||
| auto ctx = ProcessFuncGraphCall(cnode); | |||||
| // output comment | |||||
| OutputStatementComment(ofs, cnode); | |||||
| if (ctx != nullptr) { | |||||
| ofs << " @ctx.addr=" << ctx.get(); | |||||
| } | |||||
| ofs << "\n"; | |||||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | |||||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" | |||||
| << label_manage::Label(cnode->debug_info()) << "\n"; | |||||
| } else { | |||||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; | |||||
| } | |||||
| } | |||||
| void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | ||||
| const FuncGraphPtr &func_graph) { | const FuncGraphPtr &func_graph) { | ||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| @@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto &inputs = cnode->inputs(); | |||||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | |||||
| // non-return node | |||||
| if (node != func_graph->get_return()) { | |||||
| int apply_idx = idx++; | |||||
| apply_map[node] = apply_idx; | |||||
| std::string type_info = GetNodeType(node); | |||||
| if (type_info == "Undefined") { | |||||
| ofs << " %" << apply_idx << " = " << op_text << "("; | |||||
| } else { | |||||
| ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; | |||||
| } | |||||
| } else { | |||||
| ofs << " " << op_text << "("; | |||||
| } | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| if (i != 1) { | |||||
| ofs << ", "; | |||||
| } | |||||
| AnfNodePtr arg = inputs[i]; | |||||
| ofs << GetAnfNodeText(func_graph, arg, apply_map); | |||||
| } | |||||
| ofs << ")"; | |||||
| // process function graph call | |||||
| auto ctx = ProcessFuncGraphCall(cnode); | |||||
| // output comment | |||||
| OutputStatementComment(ofs, cnode); | |||||
| if (ctx != nullptr) { | |||||
| ofs << " @ctx.addr=" << ctx.get(); | |||||
| } | |||||
| ofs << "\n"; | |||||
| if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { | |||||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" | |||||
| << label_manage::Label(cnode->debug_info()) << "\n"; | |||||
| } else { | |||||
| ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; | |||||
| } | |||||
| OutputCNode(ofs, cnode, func_graph, &idx, &apply_map); | |||||
| } | } | ||||
| } | } | ||||
| @@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num | |||||
| return true; | return true; | ||||
| } | } | ||||
| void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||||
| void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, | |||||
| const size_t type_number) { | const size_t type_number) { | ||||
| *max_type_id = type_id; | *max_type_id = type_id; | ||||
| *max_type = type; | *max_type = type; | ||||
| *max_type_number = type_number; | *max_type_number = type_number; | ||||
| } | } | ||||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs, | |||||
| const std::set<size_t> &write_indexs) { | |||||
| bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, | |||||
| TypeId *arg_type = nullptr) { | |||||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||||
| if (is_write) { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||||
| } else { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||||
| } | |||||
| } | |||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_type = tensor->element()->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| *arg_type_id = tensor_type->type_id(); | |||||
| if (arg_type != nullptr) { | |||||
| *arg_type = kObjectTypeTensorType; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| if (arg_value->isa<abstract::AbstractScalar>()) { | |||||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||||
| auto scalar_type = scalar->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||||
| *arg_type_id = scalar_type->type_id(); | |||||
| if (arg_type != nullptr) { | |||||
| *arg_type = kObjectTypeNumber; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indices, | |||||
| const std::set<size_t> &write_indices) { | |||||
| TypeId max_type_id = kTypeUnknown; | TypeId max_type_id = kTypeUnknown; | ||||
| TypeId max_type = kTypeUnknown; | TypeId max_type = kTypeUnknown; | ||||
| size_t max_type_number = 0; | size_t max_type_number = 0; | ||||
| bool has_int8 = false; | bool has_int8 = false; | ||||
| for (const auto &index : indexs) { | |||||
| for (const auto &index : indices) { | |||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| TypeId arg_type = kTypeUnknown; | TypeId arg_type = kTypeUnknown; | ||||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||||
| auto is_write = (write_indexs.find(index) != write_indexs.end()); | |||||
| if (is_write) { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||||
| } else { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||||
| } | |||||
| } | |||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_type = tensor->element()->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| arg_type_id = tensor_type->type_id(); | |||||
| arg_type = kObjectTypeTensorType; | |||||
| } else if (arg_value->isa<abstract::AbstractScalar>()) { | |||||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||||
| auto scalar_type = scalar->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||||
| arg_type_id = scalar_type->type_id(); | |||||
| arg_type = kObjectTypeNumber; | |||||
| } else { | |||||
| auto is_write = (write_indices.find(index) != write_indices.end()); | |||||
| if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto it = type_map.find(arg_type_id); | auto it = type_map.find(arg_type_id); | ||||
| @@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| has_int8 = true; | has_int8 = true; | ||||
| } | } | ||||
| if (max_type_id == kTypeUnknown) { | if (max_type_id == kTypeUnknown) { | ||||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (max_type == arg_type) { | if (max_type == arg_type) { | ||||
| if (it->second > max_type_number) { | if (it->second > max_type_number) { | ||||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (arg_type == kObjectTypeTensorType) { | if (arg_type == kObjectTypeTensorType) { | ||||
| if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { | if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { | ||||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { | if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { | ||||
| setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve | |||||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | // Get the largest type of index in the same SignatureEnumDType of arguments. | ||||
| std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, | ||||
| const abstract::AbstractBasePtrList &args_spec_list, | const abstract::AbstractBasePtrList &args_spec_list, | ||||
| const std::set<size_t> &write_indexs) { | |||||
| const std::set<size_t> &write_indices) { | |||||
| // record index for signature.dtypes of the same type | // record index for signature.dtypes of the same type | ||||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | ||||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | |||||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indices; | |||||
| for (size_t i = 0; i < dtypes.size(); ++i) { | for (size_t i = 0; i < dtypes.size(); ++i) { | ||||
| auto it = type_indexs.find(dtypes[i]); | |||||
| if (it == type_indexs.end()) { | |||||
| (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); | |||||
| auto it = type_indices.find(dtypes[i]); | |||||
| if (it == type_indices.end()) { | |||||
| (void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); | |||||
| } else { | } else { | ||||
| it->second.push_back(i); | it->second.push_back(i); | ||||
| } | } | ||||
| } | } | ||||
| std::map<SignatureEnumDType, TypeId> dst_type; | std::map<SignatureEnumDType, TypeId> dst_type; | ||||
| for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { | |||||
| for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { | |||||
| auto type = it->first; | auto type = it->first; | ||||
| auto indexs = it->second; | |||||
| auto indices = it->second; | |||||
| // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. | // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. | ||||
| if (indexs.size() < 2) { | |||||
| if (indices.size() < 2) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| bool has_tensor = false; | bool has_tensor = false; | ||||
| for (const auto &index : indexs) { | |||||
| for (const auto &index : indices) { | |||||
| AbstractBasePtr arg_value = args_spec_list[index]; | AbstractBasePtr arg_value = args_spec_list[index]; | ||||
| if (arg_value->isa<abstract::AbstractRef>()) { | if (arg_value->isa<abstract::AbstractRef>()) { | ||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | ||||
| @@ -189,7 +201,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum | |||||
| (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); | ||||
| continue; | continue; | ||||
| } | } | ||||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs))); | |||||
| (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); | |||||
| } | } | ||||
| return dst_type; | return dst_type; | ||||
| } | } | ||||
| @@ -204,7 +216,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap | |||||
| void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature, | ||||
| const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, | const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, | ||||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) { | |||||
| std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indices) { | |||||
| std::vector<SignatureEnumDType> dtypes; | std::vector<SignatureEnumDType> dtypes; | ||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | ||||
| [](const Signature &sig) { return sig.dtype; }); | [](const Signature &sig) { return sig.dtype; }); | ||||
| @@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign | |||||
| return; | return; | ||||
| } | } | ||||
| // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | // Stat the index of the arguments with the largest type in the same SignatureEnumDType. | ||||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs); | |||||
| std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); | |||||
| // Identify which arg requires auto cast | // Identify which arg requires auto cast | ||||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | for (size_t i = 0; i < args_spec_list.size(); ++i) { | ||||
| auto it = dst_type.find(dtypes[i]); | auto it = dst_type.find(dtypes[i]); | ||||
| if (it == dst_type.end() || it->second == kTypeUnknown) { | if (it == dst_type.end() || it->second == kTypeUnknown) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto rw_it = write_indexs.find(i); | |||||
| auto is_write = (rw_it != write_indexs.end()); | |||||
| auto rw_it = write_indices.find(i); | |||||
| auto is_write = (rw_it != write_indices.end()); | |||||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||||
| if (is_write) { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); | |||||
| } else { | |||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||||
| } | |||||
| } | |||||
| TypeId arg_type_id = kTypeUnknown; | TypeId arg_type_id = kTypeUnknown; | ||||
| if (arg_value->isa<abstract::AbstractTensor>()) { | |||||
| auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_type = tensor->element()->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| arg_type_id = tensor_type->type_id(); | |||||
| } else if (arg_value->isa<abstract::AbstractScalar>()) { | |||||
| auto scalar = arg_value->cast<abstract::AbstractScalarPtr>(); | |||||
| auto scalar_type = scalar->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(scalar_type); | |||||
| arg_type_id = scalar_type->type_id(); | |||||
| } | |||||
| AbstractBasePtr arg_value = args_spec_list[i]; | |||||
| (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); | |||||
| auto it_map = type_map.find(arg_type_id); | auto it_map = type_map.find(arg_type_id); | ||||
| if (it_map == type_map.end()) { | if (it_map == type_map.end()) { | ||||
| continue; | continue; | ||||
| @@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| } | } | ||||
| } | } | ||||
| std::vector<AnfNodePtr> op_inputs; | std::vector<AnfNodePtr> op_inputs; | ||||
| std::set<size_t> write_indexs; | |||||
| std::set<size_t> write_indices; | |||||
| op_inputs.push_back(NewValueNode(function)); | op_inputs.push_back(NewValueNode(function)); | ||||
| // Assume, the write input of op is always the first input. We check if any write op, | // Assume, the write input of op is always the first input. We check if any write op, | ||||
| // and add cast op on other inputs to keep the same type with assigned parameter. | // and add cast op on other inputs to keep the same type with assigned parameter. | ||||
| @@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); | param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); | ||||
| } else if (sig == SignatureEnumRW::kRWWrite) { | } else if (sig == SignatureEnumRW::kRWWrite) { | ||||
| param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); | param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); | ||||
| write_indexs.insert(i); | |||||
| write_indices.insert(i); | |||||
| } | } | ||||
| // If sig is SignatureEnumRW::kRWRef, not do anything. | // If sig is SignatureEnumRW::kRWRef, not do anything. | ||||
| } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { | ||||
| @@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func | |||||
| } | } | ||||
| // process default | // process default | ||||
| ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); | ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); | ||||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs); | |||||
| DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); | |||||
| return func_graph->NewCNode(op_inputs); | return func_graph->NewCNode(op_inputs); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) { | |||||
| return bprop_graph; | return bprop_graph; | ||||
| } | } | ||||
| bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | |||||
| if (func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Parse resolve function error."; | |||||
| return false; | |||||
| } | |||||
| // if the cell object has specified bprop, it has user-defined bprop function parse and record it | |||||
| if (py::hasattr(obj, "bprop")) { | |||||
| FuncGraphPtr bprop_graph = nullptr; | |||||
| bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug")); | |||||
| if (enable_bprop_debug) { | |||||
| bprop_graph = ConvertToBpropCut(obj); | |||||
| } else { | |||||
| bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); | |||||
| } | |||||
| if (bprop_graph != nullptr) { | |||||
| (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); | |||||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | |||||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||||
| } | |||||
| } | |||||
| *data = func_graph; | |||||
| return true; | |||||
| } | |||||
| bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | ||||
| auto obj_type = data_converter::GetObjType(obj); | auto obj_type = data_converter::GetObjType(obj); | ||||
| MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; | MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; | ||||
| @@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||||
| // Create the namespace for common class instance | // Create the namespace for common class instance | ||||
| // When the obj is Cell, default parse the 'construct' | // When the obj is Cell, default parse the 'construct' | ||||
| if (data_converter::IsCellInstance(obj)) { | if (data_converter::IsCellInstance(obj)) { | ||||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | |||||
| if (func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Parse resolve function error."; | |||||
| return false; | |||||
| } | |||||
| // if the cell object has specified bprop, it has user-defined bprop function parse and record it | |||||
| if (py::hasattr(obj, "bprop")) { | |||||
| FuncGraphPtr bprop_graph = nullptr; | |||||
| bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug")); | |||||
| if (enable_bprop_debug) { | |||||
| bprop_graph = ConvertToBpropCut(obj); | |||||
| } else { | |||||
| bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); | |||||
| } | |||||
| if (bprop_graph != nullptr) { | |||||
| (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); | |||||
| (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); | |||||
| func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); | |||||
| } | |||||
| } | |||||
| *data = func_graph; | |||||
| } else { | |||||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | |||||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||||
| return ConvertCellObjToFuncGraph(obj, data); | |||||
| } | } | ||||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | |||||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||||
| return true; | return true; | ||||
| } | } | ||||
| MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); | MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); | ||||
| @@ -608,7 +608,7 @@ void Pipeline::Run() { | |||||
| MS_LOG(INFO) << "End"; | MS_LOG(INFO) << "End"; | ||||
| } | } | ||||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) { | |||||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { | |||||
| std::size_t size = args.size(); | std::size_t size = args.size(); | ||||
| for (std::size_t i = 0; i < size; i++) { | for (std::size_t i = 0; i < size; i++) { | ||||
| @@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | ||||
| const std::vector<int64_t> &input_indexes, bool need_run); | const std::vector<int64_t> &input_indexes, bool need_run); | ||||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list); | |||||
| void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); | |||||
| } // namespace pipeline | } // namespace pipeline | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> | |||||
| return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); | ||||
| } | } | ||||
| void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||||
| auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>(); | |||||
| if (fg_eval == nullptr) { | |||||
| return; | |||||
| } | |||||
| auto fg = fg_eval->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto undetermined_fgs = fg->recursive_graphs(); | |||||
| if (undetermined_fgs) { | |||||
| auto fg_parent = fg->parent(); | |||||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||||
| } | |||||
| } | |||||
| EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, | |||||
| const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, | |||||
| const EvalTraceRevIter &it, bool *continue_flag) { | |||||
| *continue_flag = false; | |||||
| // Find latest entry function to handle nested recursion. | |||||
| EvaluatorPtr latest_entry = eval; | |||||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||||
| if (it_temp != evaluators.end()) { | |||||
| latest_entry = *it_temp; | |||||
| latest_entry_iter = r_it; | |||||
| break; | |||||
| } | |||||
| latest_entry_iter = ++r_it; | |||||
| } | |||||
| if (latest_entry != eval) { | |||||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||||
| *continue_flag = true; | |||||
| return latest_entry; | |||||
| } | |||||
| bool has_undetermined = false; | |||||
| // Check whether sub loop has untraced undetermined evaluator. | |||||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||||
| undetermined_evals.insert(*r_it); | |||||
| } | |||||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||||
| for (auto u_eval : undetermined_evals) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||||
| has_undetermined = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_undetermined == false) { | |||||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||||
| *continue_flag = true; | |||||
| return latest_entry; | |||||
| } | |||||
| return latest_entry; | |||||
| } | |||||
| EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { | |||||
| if (out_specs.size() == 0) { | |||||
| MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; | |||||
| } | |||||
| if (out_specs.size() == 1) { | |||||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | |||||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | |||||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| auto joined_spec = AbstractJoin(out_specs); | |||||
| MS_EXCEPTION_IF_NULL(joined_spec); | |||||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | |||||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, | ||||
| const AnfNodeConfigPtr &out_conf, | const AnfNodeConfigPtr &out_conf, | ||||
| const ConfigPtrList &args_conf_list) { | const ConfigPtrList &args_conf_list) { | ||||
| @@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||||
| return conf->GetEvaluatedValue()->abstract(); | return conf->GetEvaluatedValue()->abstract(); | ||||
| }); | }); | ||||
| for (auto eval : evaluators) { | for (auto eval : evaluators) { | ||||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | |||||
| if (fg_eval) { | |||||
| auto fg = fg_eval->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto undetermined_fgs = fg->recursive_graphs(); | |||||
| if (undetermined_fgs) { | |||||
| auto fg_parent = fg->parent(); | |||||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||||
| } | |||||
| } | |||||
| SetUndeterminedFlag(eval); | |||||
| auto current_inf = std::make_pair(eval, args_spec_list); | auto current_inf = std::make_pair(eval, args_spec_list); | ||||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | ||||
| @@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||||
| multi_poss_.clear(); | multi_poss_.clear(); | ||||
| } | } | ||||
| } else if (it != eval_trace_.rbegin()) { | } else if (it != eval_trace_.rbegin()) { | ||||
| // Find latest entry function to handle nested recursion. | |||||
| EvaluatorPtr latest_entry = eval; | |||||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||||
| if (it_temp != evaluators.end()) { | |||||
| latest_entry = *it_temp; | |||||
| latest_entry_iter = r_it; | |||||
| break; | |||||
| } | |||||
| latest_entry_iter = ++r_it; | |||||
| } | |||||
| if (latest_entry != eval) { | |||||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||||
| continue; | |||||
| } | |||||
| bool has_undetermined = false; | |||||
| // Check whether sub loop has untraced undetermined evaluator. | |||||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||||
| undetermined_evals.insert(*r_it); | |||||
| } | |||||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||||
| for (auto u_eval : undetermined_evals) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||||
| has_undetermined = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_undetermined == false) { | |||||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||||
| bool continue_flag = false; | |||||
| auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); | |||||
| if (continue_flag) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (out_specs.size() == 0) { | |||||
| MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; | |||||
| } | |||||
| if (out_specs.size() == 1) { | |||||
| MS_EXCEPTION_IF_NULL(out_specs[0]); | |||||
| // If only one result derived, then broaden it to avoid wrong constant propagation. | |||||
| return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| auto joined_spec = AbstractJoin(out_specs); | |||||
| MS_EXCEPTION_IF_NULL(joined_spec); | |||||
| MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); | |||||
| return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>()); | |||||
| return ProcessEvalResults(out_specs); | |||||
| } | } | ||||
| EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { | ||||
| @@ -172,6 +172,8 @@ struct AnalysisResult { | |||||
| AnalysisContextPtr context; | AnalysisContextPtr context; | ||||
| }; | }; | ||||
| using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator; | |||||
| class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | ||||
| public: | public: | ||||
| AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | ||||
| @@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | ||||
| private: | private: | ||||
| void SetUndeterminedFlag(const EvaluatorPtr &evaluator); | |||||
| EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | |||||
| const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, | |||||
| bool *continue_flag); | |||||
| EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); | |||||
| const PrimEvaluatorMap &prim_constructors_; | const PrimEvaluatorMap &prim_constructors_; | ||||
| FuncGraphManagerPtr func_graph_manager_; | FuncGraphManagerPtr func_graph_manager_; | ||||
| std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; | std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_; | ||||