| @@ -94,7 +94,7 @@ PenaltyBreakString: 1000 | |||
| PenaltyBreakTemplateDeclaration: 10 | |||
| PenaltyExcessCharacter: 1000000 | |||
| PenaltyReturnTypeOnItsOwnLine: 200 | |||
| PointerAlignment: Left | |||
| PointerAlignment: Right | |||
| RawStringFormats: | |||
| - Language: Cpp | |||
| Delimiters: | |||
| @@ -23,7 +23,7 @@ namespace common { | |||
| const int CACHED_STR_NUM = 1 << 8; | |||
| const int CACHED_STR_MASK = CACHED_STR_NUM - 1; | |||
| std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); | |||
| const char* SafeCStr(const std::string&& str) { | |||
| const char *SafeCStr(const std::string &&str) { | |||
| static std::atomic<uint32_t> index{0}; | |||
| uint32_t cur_index = index++; | |||
| cur_index = cur_index & CACHED_STR_MASK; | |||
| @@ -21,16 +21,16 @@ | |||
| #include <string> | |||
| #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | |||
| ClassType(const ClassType&) = delete; \ | |||
| ClassType& operator=(const ClassType&) = delete; | |||
| ClassType(const ClassType &) = delete; \ | |||
| ClassType &operator=(const ClassType &) = delete; | |||
| namespace mindspore { | |||
| namespace common { | |||
| inline const char* SafeCStr(const std::string& str) { return str.c_str(); } | |||
| const char* SafeCStr(const std::string&& str); | |||
| inline const char *SafeCStr(const std::string &str) { return str.c_str(); } | |||
| const char *SafeCStr(const std::string &&str); | |||
| static inline std::string GetEnv(const std::string& envvar) { | |||
| const char* value = ::getenv(envvar.c_str()); | |||
| static inline std::string GetEnv(const std::string &envvar) { | |||
| const char *value = ::getenv(envvar.c_str()); | |||
| if (value == nullptr) { | |||
| return std::string(); | |||
| @@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { | |||
| ~DecodeOp() = default; | |||
| Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override; | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| void Print(std::ostream& out) const override { out << "DecodeOp"; } | |||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||
| void Print(std::ostream &out) const override { out << "DecodeOp"; } | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| private: | |||
| bool is_rgb_format_ = true; | |||
| @@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int | |||
| rnd_.seed(seed_); | |||
| } | |||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||
| std::vector<std::shared_ptr<Tensor>>* output) { | |||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
| std::vector<std::shared_ptr<Tensor>> *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| if (input.size() != NumInput()) | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); | |||
| @@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso | |||
| return Status::OK(); | |||
| } | |||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs, | |||
| std::vector<TensorShape>& outputs) { | |||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs, | |||
| std::vector<TensorShape> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
| outputs.clear(); | |||
| TensorShape out = TensorShape{-1, -1}; | |||
| @@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) { | |||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | |||
| outputs[0] = inputs[0]; | |||
| return Status::OK(); | |||
| @@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp { | |||
| ~DistortBoundingBoxCropOp() override = default; | |||
| void Print(std::ostream& out) const override { | |||
| void Print(std::ostream &out) const override { | |||
| out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; | |||
| } | |||
| Status Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||
| std::vector<std::shared_ptr<Tensor>>* output) override; | |||
| Status Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||
| std::vector<std::shared_ptr<Tensor>> *output) override; | |||
| uint32_t NumInput() override { return 5; } | |||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| private: | |||
| int32_t max_attempts_; | |||
| @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) { | |||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); | |||
| @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std: | |||
| (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); | |||
| return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); | |||
| } | |||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) { | |||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | |||
| outputs.clear(); | |||
| TensorShape out = TensorShape{target_height_, target_width_}; | |||
| @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs | |||
| if (!outputs.empty()) return Status::OK(); | |||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | |||
| } | |||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { | |||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { | |||
| double scale, aspect; | |||
| *crop_width = w_in; | |||
| *crop_height = h_in; | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| constexpr char PARALLEL_STRATEGY[] = "strategy"; | |||
| void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); | |||
| void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); | |||
| } // namespace mindspore | |||
| @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; | |||
| // get MindSpore Intermediate Representation Path | |||
| std::string GetMsIrPath(void) { | |||
| std::string path; | |||
| const char* path_ptr = getenv("MS_IR_PATH"); | |||
| const char *path_ptr = getenv("MS_IR_PATH"); | |||
| if (path_ptr != nullptr) { | |||
| path = path_ptr; | |||
| char real_path[PATH_MAX] = {0}; | |||
| @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { | |||
| return path; | |||
| } | |||
| std::string dump_obj(const py::object& obj, const std::string& path) { | |||
| std::string dump_obj(const py::object &obj, const std::string &path) { | |||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||
| py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); | |||
| return py::str(name); | |||
| } | |||
| py::object load_obj(const std::string& path) { | |||
| py::object load_obj(const std::string &path) { | |||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | |||
| py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); | |||
| return obj; | |||
| @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { | |||
| // ============================================= MindSpore IR Exporter ============================================= | |||
| std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||
| std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { | |||
| abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | |||
| TypePtr type = dyn_cast<Type>(nd->Type()); | |||
| std::ostringstream oss; | |||
| @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { | |||
| std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { | |||
| std::string pkl_path = GetMsIrPath(); | |||
| // if not specified env 'MS_IR_PATH', do not create any files | |||
| if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { | |||
| @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca | |||
| return file_prefix + file_name; | |||
| } | |||
| int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { | |||
| int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { | |||
| if (func_graph == nullptr || param == nullptr) { | |||
| return -1; | |||
| } | |||
| @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& | |||
| // try to find index of parameter for SymbolicKeyInstance from all exported graphs | |||
| // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different | |||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { | |||
| if (param == nullptr) { | |||
| return -1; | |||
| } | |||
| int ret = -1; | |||
| for (const auto& item : exported) { | |||
| for (const auto &item : exported) { | |||
| auto pram_iter = item.second.find(param); | |||
| if (pram_iter != item.second.end()) { | |||
| return pram_iter->second; | |||
| @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||
| return ret; | |||
| } | |||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { | |||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| return GetValueText(fg, node->value()); | |||
| } | |||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { | |||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { | |||
| auto py_funcs = mt_func_graph->GetPyFunctions(); | |||
| if (py_funcs.empty()) { | |||
| return ""; | |||
| @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||
| oss << "{"; | |||
| bool is_first = true; | |||
| for (const auto& py_func : py_funcs) { | |||
| for (const auto &py_func : py_funcs) { | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||
| * ├── GradOperation | |||
| * └── TupleAdd | |||
| */ | |||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { | |||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { | |||
| if (meta_func_graph == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { | |||
| std::ostringstream oss; | |||
| if (prim == nullptr) { | |||
| return oss.str(); | |||
| @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| if (prim->isa<prim::DoSignaturePrimitive>()) { | |||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | |||
| auto& func = do_signature->function(); | |||
| auto &func = do_signature->function(); | |||
| if (func->isa<Primitive>()) { | |||
| auto sig_prim = dyn_cast<Primitive>(func); | |||
| oss << sig_prim->GetAttrsText(); | |||
| @@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { | |||
| std::ostringstream oss; | |||
| if (ns == nullptr) { | |||
| return oss.str(); | |||
| @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, | |||
| const SymbolicKeyInstancePtr& sym_inst) { | |||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, | |||
| const SymbolicKeyInstancePtr &sym_inst) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(sym_inst); | |||
| AnfNodePtr sym_node = sym_inst->node(); | |||
| @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| // output ValueList, ValueTuple | |||
| ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value); | |||
| @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>(); | |||
| oss << "{"; | |||
| bool first_flag = true; | |||
| for (const auto& elem : dict->value()) { | |||
| for (const auto &elem : dict->value()) { | |||
| if (first_flag) { | |||
| first_flag = false; | |||
| } else { | |||
| @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | |||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| if (check_integrity_) { | |||
| @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& | |||
| return oss.str(); | |||
| } | |||
| std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||
| std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| bool is_null_ptr = (func_graph == nullptr || value == nullptr); | |||
| if (is_null_ptr) { | |||
| @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu | |||
| } | |||
| // this function is used to output node in CNode's inputs | |||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, int>& apply_map) { | |||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, int> &apply_map) { | |||
| std::ostringstream oss; | |||
| if (func_graph == nullptr || node == nullptr) { | |||
| return oss.str(); | |||
| @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An | |||
| return oss.str(); | |||
| } | |||
| void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) { | |||
| void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) { | |||
| bool first_flag = true; | |||
| for (const AnfNodePtr& param : parameters) { | |||
| for (const AnfNodePtr ¶m : parameters) { | |||
| if (first_flag) { | |||
| first_flag = false; | |||
| ofs << " "; | |||
| @@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode | |||
| } | |||
| } | |||
| void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) { | |||
| void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| // output type of each input argument | |||
| auto& inputs = node->inputs(); | |||
| auto &inputs = node->inputs(); | |||
| if (inputs.size() > 1) { | |||
| ofs << " #("; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod | |||
| ofs << " #scope: " << node->scope()->name(); | |||
| } | |||
| void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, | |||
| const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | |||
| const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| int idx = 1; | |||
| std::map<AnfNodePtr, int> apply_map; | |||
| for (const AnfNodePtr& node : nodes) { | |||
| for (const AnfNodePtr &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto& inputs = cnode->inputs(); | |||
| auto &inputs = cnode->inputs(); | |||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | |||
| // non-return node | |||
| if (node != func_graph->get_return()) { | |||
| @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||
| } | |||
| } | |||
| void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun | |||
| ofs << "}\n"; | |||
| } | |||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt | |||
| ofs.close(); | |||
| } | |||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| if (graphs.empty()) { | |||
| return; | |||
| } | |||
| @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||
| param_index = 1; | |||
| for (const auto& tagged_graph : graphs) { | |||
| for (const auto &tagged_graph : graphs) { | |||
| tagged_cnodes_ = tagged_graph.second; | |||
| ExportOneFuncGraph(ofs, tagged_graph.first); | |||
| tagged_cnodes_.clear(); | |||
| @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { | |||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap | |||
| ChangeFileMode(filename, S_IRUSR); | |||
| } | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| AnfExporter exporter("", false); | |||
| ChangeFileMode(filename, S_IRWXU); | |||
| exporter.ExportFuncGraph(filename, graphs); | |||
| @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph | |||
| ChangeFileMode(filename, S_IRUSR); | |||
| } | |||
| #else | |||
| void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||
| void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||
| << "please recompile source to enable it. See help of building script."; | |||
| } | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -732,7 +732,7 @@ enum Token : int { | |||
| TOK_ERROR // file read error | |||
| }; | |||
| std::map<Token, const char*> token_text = { | |||
| std::map<Token, const char *> token_text = { | |||
| {TOK_INVALID, "invalid"}, // invalid token | |||
| {TOK_LPARENTHESIS, "("}, // ( left parenthesis | |||
| {TOK_RPARENTHESIS, ")"}, // ) right parenthesis | |||
| @@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = { | |||
| class Lexer { | |||
| public: | |||
| // filename is checked in ImportIR; | |||
| explicit Lexer(const char* filename) : fin(filename) {} | |||
| explicit Lexer(const char *filename) : fin(filename) {} | |||
| ~Lexer() { | |||
| try { | |||
| if (fin.is_open()) { | |||
| fin.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Exception when closing file"; | |||
| } catch (...) { | |||
| std::string exName(abi::__cxa_current_exception_type()->name()); | |||
| @@ -776,7 +776,7 @@ class Lexer { | |||
| } | |||
| } | |||
| bool IsSingleCharToken(char ch, Token* token_ptr) { | |||
| bool IsSingleCharToken(char ch, Token *token_ptr) { | |||
| // clang-format off | |||
| std::unordered_map<char, Token> char_to_token = { | |||
| {'(', TOK_LPARENTHESIS}, | |||
| @@ -806,7 +806,7 @@ class Lexer { | |||
| Token GetNextToken() { | |||
| #ifdef DEBUG | |||
| Token token = GetNextTokenInner(); | |||
| const char* str = token_text[token]; | |||
| const char *str = token_text[token]; | |||
| std::string text = (str == nullptr ? GetTokenText() : str); | |||
| MS_LOG(DEBUG) << "------Parse token] " << text; | |||
| return token; | |||
| @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; | |||
| class IrParser { | |||
| public: | |||
| explicit IrParser(const char* filename) : lexer_(filename) {} | |||
| explicit IrParser(const char *filename) : lexer_(filename) {} | |||
| ~IrParser() {} | |||
| py::object LoadObject(const std::string& file_name) const { | |||
| py::object LoadObject(const std::string &file_name) const { | |||
| std::string pkl_path = GetMsIrPath(); | |||
| py::object default_obj = load_obj(pkl_path + "/" + file_name); | |||
| return default_obj; | |||
| @@ -1087,7 +1087,7 @@ class IrParser { | |||
| MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); | |||
| } | |||
| Token ParseParent(FuncGraphPtr* const parent_ptr) { | |||
| Token ParseParent(FuncGraphPtr *const parent_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1168,7 +1168,7 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { | |||
| FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { | |||
| Token tok = lexer_.SkipWhiteToken(); | |||
| while (tok == TOK_VARIABLE) { | |||
| if (ParseStatement(func_graph) == nullptr) { | |||
| @@ -1264,56 +1264,56 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { | |||
| void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = dtype; | |||
| } | |||
| void SetTupleType(TypePtr* ptr) { | |||
| void SetTupleType(TypePtr *ptr) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<Tuple>(); | |||
| } | |||
| void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { | |||
| void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<Tuple>(elems); | |||
| } | |||
| void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) { | |||
| void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<TensorType>(elem_type); | |||
| } | |||
| void SetListType(TypePtr* ptr) { | |||
| void SetListType(TypePtr *ptr) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<List>(); | |||
| } | |||
| void SetListType(TypePtr* ptr, const TypePtrList& elems) { | |||
| void SetListType(TypePtr *ptr, const TypePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<List>(elems); | |||
| } | |||
| void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { | |||
| void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<JTagged>(elem); | |||
| } | |||
| void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { | |||
| void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| @@ -1321,45 +1321,45 @@ class IrParser { | |||
| } | |||
| // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} | |||
| void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { | |||
| void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractNone>(); | |||
| } | |||
| void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} | |||
| void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} | |||
| void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} | |||
| void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} | |||
| void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||
| void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| // if one of elems is nullptr, just return | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractTuple>(elems); | |||
| } | |||
| void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) { | |||
| void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape); | |||
| } | |||
| void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||
| void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||
| return; | |||
| } | |||
| *ptr = std::make_shared<abstract::AbstractList>(elems); | |||
| } | |||
| void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { | |||
| void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { | |||
| if (ptr == nullptr) { | |||
| return; | |||
| } | |||
| @@ -1367,7 +1367,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { | |||
| Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { | |||
| if (tok != TOK_LBRACKET) { | |||
| MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; | |||
| return tok; | |||
| @@ -1415,7 +1415,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||
| Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| if (ptr != nullptr) { | |||
| SetBasicType(ptr, std::make_shared<TensorType>()); | |||
| @@ -1454,7 +1454,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { | |||
| bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { | |||
| // clang-format off | |||
| static std::unordered_map<std::string, TypeId> basic_types = { | |||
| {"Bool", kNumberTypeBool}, | |||
| @@ -1486,7 +1486,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { | |||
| void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { | |||
| TypePtr dtype = nullptr; | |||
| std::unordered_map<int, TypePtr> type_map = { | |||
| @@ -1519,7 +1519,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { | |||
| Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { | |||
| if (type == "NoneType") { | |||
| SetBasicType(ptr, std::make_shared<TypeNone>()); | |||
| return lexer_.GetNextToken(); | |||
| @@ -1541,7 +1541,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||
| Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||
| if (tok != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1588,11 +1588,11 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { | |||
| Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { | |||
| return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); | |||
| } | |||
| Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||
| Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||
| Token tok = ParseAttribute(func_graph, prim); | |||
| while (tok == TOK_COMMA) { | |||
| tok = ParseAttribute(func_graph, prim); | |||
| @@ -1603,7 +1603,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||
| Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||
| Token tok = lexer_.GetNextToken(); | |||
| if (tok != TOK_IDENTIFIER) { | |||
| return TOK_ERROR; | |||
| @@ -1670,7 +1670,7 @@ class IrParser { | |||
| return tok == TOK_RPARENTHESIS ? func_graph : nullptr; | |||
| } | |||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||
| Token tok = ParseArgument(func_graph, inputs_ptr); | |||
| while (tok == TOK_COMMA) { | |||
| tok = ParseArgument(func_graph, inputs_ptr); | |||
| @@ -1681,9 +1681,9 @@ class IrParser { | |||
| return func_graph; | |||
| } | |||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { | |||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { | |||
| while (func_graph != nullptr) { | |||
| for (auto& ptr : func_graph->parameters()) { | |||
| for (auto &ptr : func_graph->parameters()) { | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| ParameterPtr param = ptr->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| @@ -1701,12 +1701,12 @@ class IrParser { | |||
| return nullptr; | |||
| } | |||
| bool Match(const std::string& str, const std::string& pattern) const { | |||
| bool Match(const std::string &str, const std::string &pattern) const { | |||
| return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; | |||
| } | |||
| template <typename T, typename V> | |||
| Token ParseScalar(ValuePtr* const val_ptr) { | |||
| Token ParseScalar(ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_NUMBER) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1725,7 +1725,7 @@ class IrParser { | |||
| } | |||
| template <typename VT, typename V, typename T> | |||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| *val_ptr = std::make_shared<T>(); | |||
| return tok; | |||
| @@ -1735,7 +1735,7 @@ class IrParser { | |||
| } | |||
| template <typename VT, typename V, typename T, const unsigned nbits> | |||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| *val_ptr = std::make_shared<T>(nbits); | |||
| return tok; | |||
| @@ -1745,7 +1745,7 @@ class IrParser { | |||
| } | |||
| template <typename T> | |||
| T StringToScalar(const std::string& text) { | |||
| T StringToScalar(const std::string &text) { | |||
| std::stringstream ss; | |||
| T value; | |||
| ss << text; | |||
| @@ -1753,7 +1753,7 @@ class IrParser { | |||
| return value; | |||
| } | |||
| Token ParseTensor(ValuePtr* const val_ptr) { | |||
| Token ParseTensor(ValuePtr *const val_ptr) { | |||
| // parse type | |||
| TypeId type; | |||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | |||
| @@ -1803,7 +1803,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { | |||
| Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { | |||
| if (tok != TOK_LBRACE) { | |||
| return tok; | |||
| } | |||
| @@ -1830,7 +1830,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||
| if (tok != TOK_LPARENTHESIS) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1855,7 +1855,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||
| if (tok != TOK_LBRACE) { | |||
| return tok; | |||
| } | |||
| @@ -1868,7 +1868,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseBoolValue(const std::string& key, bool* val_ptr) { | |||
| Token ParseBoolValue(const std::string &key, bool *val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1892,7 +1892,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { | |||
| Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_LBRACE) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1920,7 +1920,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { | |||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { | |||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1951,7 +1951,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { | |||
| Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { | |||
| if (lexer_.GetNextToken() != TOK_AT_FILE) { | |||
| return TOK_ERROR; | |||
| } | |||
| @@ -1984,7 +1984,7 @@ class IrParser { | |||
| return next; | |||
| } | |||
| Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { | |||
| Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { | |||
| if (Match(id, "MultitypeFuncGraph::")) { | |||
| std::string name = id.substr(strlen("MultitypeFuncGraph::")); | |||
| auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name); | |||
| @@ -2024,8 +2024,8 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, | |||
| AnfNodePtr* const node_ptr = nullptr) { | |||
| Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, | |||
| AnfNodePtr *const node_ptr = nullptr) { | |||
| if (id == "None") { | |||
| *val_ptr = std::make_shared<None>(); | |||
| return lexer_.GetNextToken(); | |||
| @@ -2075,9 +2075,9 @@ class IrParser { | |||
| } | |||
| } | |||
| Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, | |||
| const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes, | |||
| ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { | |||
| Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, | |||
| const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes, | |||
| ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { | |||
| if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { | |||
| if (node_is_valid && node_ptr != nullptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -2097,8 +2097,8 @@ class IrParser { | |||
| } | |||
| } | |||
| Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, | |||
| AnfNodePtr* node_ptr = nullptr) { | |||
| Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, | |||
| AnfNodePtr *node_ptr = nullptr) { | |||
| Token left_tok = tok; | |||
| std::vector<ValuePtr> elems; | |||
| @@ -2138,7 +2138,7 @@ class IrParser { | |||
| return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); | |||
| } | |||
| Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { | |||
| Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { | |||
| // tuple or list | |||
| if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { | |||
| return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); | |||
| @@ -2152,7 +2152,7 @@ class IrParser { | |||
| return TOK_ERROR; | |||
| } | |||
| Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, | |||
| Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, | |||
| Token tok = TOK_INVALID) { | |||
| if (tok == TOK_INVALID) { | |||
| tok = lexer_.GetNextToken(); | |||
| @@ -2193,7 +2193,7 @@ class IrParser { | |||
| return lexer_.GetNextToken(); | |||
| } | |||
| Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||
| Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||
| Token tok = lexer_.GetNextToken(); | |||
| if (tok == TOK_RPARENTHESIS) { | |||
| return tok; | |||
| @@ -2208,7 +2208,7 @@ class IrParser { | |||
| return tok; | |||
| } | |||
| const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; } | |||
| const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; } | |||
| private: | |||
| Lexer lexer_; | |||
| @@ -2226,14 +2226,14 @@ class IrParser { | |||
| std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter | |||
| }; | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename) { | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename) { | |||
| IrParser parser(filename.c_str()); | |||
| parser.ParseFile(); | |||
| return parser.GetFuncGraphs(); | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Func graph is nullptr"; | |||
| return; | |||
| @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| return; | |||
| } | |||
| char real_path[PATH_MAX] = {0}; | |||
| char* real_path_ret = nullptr; | |||
| char *real_path_ret = nullptr; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | |||
| #else | |||
| @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||
| ChangeFileMode(file_path, S_IRUSR); | |||
| } | |||
| #else | |||
| void DumpIRProto(const FuncGraphPtr&, const std::string&) { | |||
| void DumpIRProto(const FuncGraphPtr &, const std::string &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -39,7 +39,7 @@ | |||
| namespace mindspore { | |||
| struct ParamPtrEqual { | |||
| bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { | |||
| bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { | |||
| const ParameterPtr param1 = dyn_cast<Parameter>(t1); | |||
| const ParameterPtr param2 = dyn_cast<Parameter>(t2); | |||
| @@ -52,7 +52,7 @@ struct ParamPtrEqual { | |||
| }; | |||
| struct ParamPtrHasher { | |||
| std::size_t operator()(AnfNodePtr const& param) const { | |||
| std::size_t operator()(AnfNodePtr const ¶m) const { | |||
| const ParameterPtr parameter = dyn_cast<Parameter>(param); | |||
| if (parameter == nullptr) { | |||
| return 0; | |||
| @@ -64,39 +64,39 @@ struct ParamPtrHasher { | |||
| class AnfExporter { | |||
| public: | |||
| explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) | |||
| explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) | |||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | |||
| func_graph_set.clear(); | |||
| exported.clear(); | |||
| } | |||
| virtual ~AnfExporter() {} | |||
| void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||
| void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||
| protected: | |||
| virtual std::string GetNodeType(const AnfNodePtr& nd); | |||
| int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | |||
| int GetParamIndexFromExported(const AnfNodePtr& param); | |||
| std::string DumpObject(const py::object& obj, const std::string& category) const; | |||
| std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); | |||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); | |||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); | |||
| std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetPrimitiveText(const PrimitivePtr& prim); | |||
| std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||
| std::string GetNameSpaceText(const parse::NameSpacePtr& ns); | |||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); | |||
| std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, int>& apply_map); | |||
| void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); | |||
| void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map); | |||
| void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); | |||
| void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph); | |||
| virtual std::string GetNodeType(const AnfNodePtr &nd); | |||
| int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); | |||
| int GetParamIndexFromExported(const AnfNodePtr ¶m); | |||
| std::string DumpObject(const py::object &obj, const std::string &category) const; | |||
| std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); | |||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); | |||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); | |||
| std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetPrimitiveText(const PrimitivePtr &prim); | |||
| std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||
| std::string GetNameSpaceText(const parse::NameSpacePtr &ns); | |||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); | |||
| std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, int> &apply_map); | |||
| void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | |||
| void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map); | |||
| void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); | |||
| void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | |||
| int param_index; | |||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | |||
| @@ -108,16 +108,16 @@ class AnfExporter { | |||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | |||
| }; | |||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | |||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); | |||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename); | |||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | |||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | |||
| @@ -34,7 +34,7 @@ namespace draw { | |||
| namespace { | |||
| // Only for ValueNode | |||
| std::string ValueType(const ValueNodePtr& node) { | |||
| std::string ValueType(const ValueNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { | |||
| return v->type_name(); | |||
| } | |||
| std::string ReplaceSpecialChar(const std::string& str) { | |||
| std::string ReplaceSpecialChar(const std::string &str) { | |||
| std::ostringstream oss; | |||
| for (size_t i = 0; i < str.size(); i++) { | |||
| if (str[i] == '<') { | |||
| @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { | |||
| } // namespace | |||
| // API of debug utils | |||
| void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs, | |||
| void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs, | |||
| bool is_user) { | |||
| if (sub_graphs == nullptr) { | |||
| return; | |||
| } | |||
| for (auto& nd : nodes) { | |||
| for (auto &nd : nodes) { | |||
| MS_EXCEPTION_IF_NULL(nd); | |||
| auto sub_graph = nd->func_graph(); | |||
| if (sub_graph != nullptr) { | |||
| @@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st | |||
| } | |||
| } | |||
| void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) { | |||
| void DrawValueNodes(const std::vector<AnfNodePtr> &nodes, | |||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) { | |||
| if (sub_graphs == nullptr) { | |||
| return; | |||
| } | |||
| int dup_idx = 0; | |||
| for (auto& nd : nodes) { | |||
| for (auto& t : SuccIncoming(nd)) { | |||
| for (auto &nd : nodes) { | |||
| for (auto &t : SuccIncoming(nd)) { | |||
| MS_EXCEPTION_IF_NULL(t); | |||
| MS_EXCEPTION_IF_NULL(nd); | |||
| if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { | |||
| @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||
| } | |||
| } | |||
| void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) { | |||
| void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) { | |||
| if (digraph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||
| } | |||
| // Draw edge | |||
| for (auto& nd : nodes) { | |||
| for (auto &nd : nodes) { | |||
| auto succs = SuccIncoming(nd); | |||
| auto num = succs.size(); | |||
| for (size_t i = 0; i < num; i++) { | |||
| auto& t = succs.at(i); | |||
| auto &t = succs.at(i); | |||
| MS_EXCEPTION_IF_NULL(t); | |||
| if (t->isa<ValueNode>() || t->isa<Parameter>()) { | |||
| if ((!is_user) || (i != 0)) { | |||
| @@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||
| } | |||
| } | |||
| void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) { | |||
| void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||
| DrawValueNodes(nodes, &sub_graphs); | |||
| // Draw subgraph | |||
| for (const auto& gsub : sub_graphs) { | |||
| for (const auto &gsub : sub_graphs) { | |||
| digraph->SubGraph(gsub.first, gsub.second); | |||
| } | |||
| @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| const std::string dot_suffix = ".dot"; | |||
| std::string filename_with_suffix = | |||
| (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; | |||
| DrawByOpt(filename_with_suffix, func_graph, false); | |||
| } | |||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||
| DrawByOpt(filename, func_graph, true); | |||
| } | |||
| #else | |||
| void Draw(const std::string&, const FuncGraphPtr&) { | |||
| void Draw(const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { | |||
| << "please recompile source to enable it. See help of building script."; | |||
| } | |||
| void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { | |||
| void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { | |||
| static bool already_printed = false; | |||
| if (already_printed) { | |||
| return; | |||
| @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { | |||
| return "plaintext"; | |||
| } | |||
| std::string Graphviz::Color(const AnfNodePtr& node) { | |||
| std::string Graphviz::Color(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -259,7 +259,7 @@ void BaseDigraph::Start() { | |||
| buffer_ << "compound=true" << std::endl; | |||
| } | |||
| void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||
| void BaseDigraph::Head(const AnfNodePtr &node, int id) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||
| } | |||
| } | |||
| void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||
| void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||
| buffer_ << ":" << idx; | |||
| } | |||
| void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { | |||
| void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return; | |||
| } | |||
| @@ -304,12 +304,12 @@ void BaseDigraph::End() { | |||
| } | |||
| } | |||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { | |||
| buffer_ << "parameters_" << key << "[shape=plaintext "; | |||
| buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; | |||
| buffer_ << "<tr><td>parameters</td></tr>"; | |||
| int count = 0; | |||
| for (auto& parameter : key->parameters()) { | |||
| for (auto ¶meter : key->parameters()) { | |||
| buffer_ << "<tr><td>"; | |||
| buffer_ << parameter->ToString(); | |||
| auto py_p = dyn_cast<Parameter>(parameter)->default_param(); | |||
| @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||
| buffer_ << "</table>>,];"; | |||
| } | |||
| void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) { | |||
| void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) { | |||
| if (key == nullptr || gsub == nullptr) { | |||
| return; | |||
| } | |||
| @@ -361,12 +361,12 @@ Digraph::~Digraph() { | |||
| if (fout_.is_open()) { | |||
| fout_.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Exception when closing file " << filename_; | |||
| } | |||
| } | |||
| static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { | |||
| static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { | |||
| size_t start_pos = 0; | |||
| while ((start_pos = str.find(from, start_pos)) != std::string::npos) { | |||
| (void)str.replace(start_pos, from.length(), to); | |||
| @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st | |||
| return str; | |||
| } | |||
| static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph_obj); | |||
| graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) | |||
| << "'>"; | |||
| @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| graph_obj->buffer() << "</td></tr>"; | |||
| graph_obj->buffer() << "<tr><td align='left'>"; | |||
| int i = 0; | |||
| for (const auto& attr : attrs) { | |||
| for (const auto &attr : attrs) { | |||
| if (i != 0) { | |||
| graph_obj->buffer() << "<br/>"; | |||
| } | |||
| @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||
| graph_obj->buffer() << "</table>>,"; | |||
| } | |||
| static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr) { | |||
| return; | |||
| } | |||
| @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| } | |||
| } | |||
| static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr || node->size() == 0) { | |||
| return; | |||
| } | |||
| @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||
| } | |||
| graph_obj->buffer() << ">"; | |||
| int i = 0; | |||
| for (auto& attr : attrs) { | |||
| for (auto &attr : attrs) { | |||
| if (i != 0) { | |||
| graph_obj->buffer() << "<br/>"; | |||
| } | |||
| @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { | |||
| if (fout_.is_open()) { | |||
| fout_.close(); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "exception when closing file " << filename_; | |||
| } | |||
| } | |||
| @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; | |||
| class Graphviz { | |||
| public: | |||
| Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} | |||
| Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} | |||
| explicit Graphviz(const std::string& name) : name_(name) {} | |||
| explicit Graphviz(const std::string &name) : name_(name) {} | |||
| virtual ~Graphviz() {} | |||
| @@ -41,8 +41,8 @@ class Graphviz { | |||
| virtual void End() {} | |||
| virtual std::string Shape(AnfNodePtr node); | |||
| std::string Color(const AnfNodePtr& node); | |||
| std::ostringstream& buffer() { return buffer_; } | |||
| std::string Color(const AnfNodePtr &node); | |||
| std::ostringstream &buffer() { return buffer_; } | |||
| std::ostringstream buffer_; | |||
| protected: | |||
| @@ -53,8 +53,8 @@ class Graphviz { | |||
| class BaseDigraph : public Graphviz { | |||
| public: | |||
| BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} | |||
| explicit BaseDigraph(const std::string& name) : Graphviz(name) {} | |||
| BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} | |||
| explicit BaseDigraph(const std::string &name) : Graphviz(name) {} | |||
| ~BaseDigraph() override = default; | |||
| virtual void Node(AnfNodePtr node, int id = 0) = 0; | |||
| @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { | |||
| void Start() override; | |||
| void End() override; | |||
| virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); | |||
| void FuncGraphParameters(const FuncGraphPtr& key); | |||
| void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub); | |||
| void FuncGraphParameters(const FuncGraphPtr &key); | |||
| void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub); | |||
| const std::string& name() const { return name_; } | |||
| const std::string &name() const { return name_; } | |||
| protected: | |||
| void Head(const AnfNodePtr& node, int id = 0); | |||
| void Tail(const AnfNodePtr& node, int idx, int id = 0); | |||
| void Tail(const FuncGraphPtr& func_graph); | |||
| void Head(const AnfNodePtr &node, int id = 0); | |||
| void Tail(const AnfNodePtr &node, int idx, int id = 0); | |||
| void Tail(const FuncGraphPtr &func_graph); | |||
| }; | |||
| class Digraph : public BaseDigraph { | |||
| public: | |||
| Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||
| explicit Digraph(const std::string& name) : BaseDigraph(name) {} | |||
| Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||
| explicit Digraph(const std::string &name) : BaseDigraph(name) {} | |||
| ~Digraph() override; | |||
| void Node(AnfNodePtr node, int id = 0) override; | |||
| @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { | |||
| class ModelDigraph : public BaseDigraph { | |||
| public: | |||
| ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||
| explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} | |||
| ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||
| explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} | |||
| ~ModelDigraph() override; | |||
| std::string Shape(AnfNodePtr node) override; | |||
| @@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { | |||
| }; | |||
| // API to draw | |||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||
| } // namespace draw | |||
| } // namespace mindspore | |||
| @@ -33,38 +33,38 @@ class ProtoExporter { | |||
| ProtoExporter() {} | |||
| ~ProtoExporter() {} | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||
| private: | |||
| void InitModelInfo(); | |||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); | |||
| std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||
| void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); | |||
| void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); | |||
| void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); | |||
| void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); | |||
| void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); | |||
| void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); | |||
| void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||
| void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||
| void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||
| void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto); | |||
| void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr, | |||
| irpb::GraphProto* graph_proto); | |||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto); | |||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto); | |||
| std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||
| void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); | |||
| void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); | |||
| void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); | |||
| void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); | |||
| void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); | |||
| void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); | |||
| void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||
| void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||
| void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||
| void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto); | |||
| void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr, | |||
| irpb::GraphProto *graph_proto); | |||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto); | |||
| static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } | |||
| irpb::ModelProto model_; | |||
| }; | |||
| static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||
| static irpb::DataType GetNumberDataType(const TypePtr &type) { | |||
| switch (type->type_id()) { | |||
| case kNumberTypeBool: | |||
| return irpb::DT_BOOL; | |||
| @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||
| } | |||
| } | |||
| void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { | |||
| void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { | |||
| if (type_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| if (shape != nullptr && shape->isa<abstract::Shape>()) { | |||
| abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | |||
| for (const auto& elem : shape_info->shape()) { | |||
| for (const auto &elem : shape_info->shape()) { | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | |||
| } | |||
| } | |||
| } else if (type->isa<Tuple>()) { | |||
| TuplePtr tuple_type = dyn_cast<Tuple>(type); | |||
| type_proto->set_data_type(irpb::DT_TUPLE); | |||
| for (const auto& elem_type : tuple_type->elements()) { | |||
| for (const auto &elem_type : tuple_type->elements()) { | |||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | |||
| } | |||
| } else if (type->isa<TypeType>()) { | |||
| @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| } else if (type->isa<List>()) { | |||
| ListPtr list_type = dyn_cast<List>(type); | |||
| type_proto->set_data_type(irpb::DT_LIST); | |||
| for (const auto& elem_type : list_type->elements()) { | |||
| for (const auto &elem_type : list_type->elements()) { | |||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | |||
| } | |||
| } else if (type->isa<TypeAnything>()) { | |||
| @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||
| } | |||
| } | |||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { | |||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { | |||
| if (node == nullptr || type_proto == nullptr) { | |||
| return; | |||
| } | |||
| SetNodeOutputType(node->Type(), node->Shape(), type_proto); | |||
| } | |||
| void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<StringImm>()) { | |||
| const StringImmPtr& value = dyn_cast<StringImm>(val); | |||
| const StringImmPtr &value = dyn_cast<StringImm>(val); | |||
| value_proto->set_dtype(irpb::DT_STRING); | |||
| value_proto->set_str_val(value->value()); | |||
| } else if (val->isa<Scalar>()) { | |||
| @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||
| } else if (val->isa<tensor::Tensor>()) { | |||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val); | |||
| value_proto->set_dtype(irpb::DT_TENSOR); | |||
| irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); | |||
| irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); | |||
| tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); | |||
| for (auto& elem : tensor_ptr->shape()) { | |||
| for (auto &elem : tensor_ptr->shape()) { | |||
| tensor_proto->add_dims(elem); | |||
| } | |||
| } else if (val->isa<TensorType>()) { | |||
| value_proto->set_dtype(irpb::DT_TYPE); | |||
| irpb::TypeProto* type_proto = value_proto->mutable_type_val(); | |||
| irpb::TypeProto *type_proto = value_proto->mutable_type_val(); | |||
| type_proto->set_data_type(irpb::DT_TENSOR); | |||
| TypePtr elem_type = dyn_cast<TensorType>(val)->element(); | |||
| type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | |||
| @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||
| } | |||
| } | |||
| void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<BoolImm>()) { | |||
| const BoolImmPtr& value = dyn_cast<BoolImm>(val); | |||
| const BoolImmPtr &value = dyn_cast<BoolImm>(val); | |||
| value_proto->set_dtype(irpb::DT_BOOL); | |||
| value_proto->set_bool_val(value->value()); | |||
| } else if (val->isa<Int8Imm>()) { | |||
| const Int8ImmPtr& value = dyn_cast<Int8Imm>(val); | |||
| const Int8ImmPtr &value = dyn_cast<Int8Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT8); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int16Imm>()) { | |||
| const Int16ImmPtr& value = dyn_cast<Int16Imm>(val); | |||
| const Int16ImmPtr &value = dyn_cast<Int16Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT16); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int32Imm>()) { | |||
| const Int32ImmPtr& value = dyn_cast<Int32Imm>(val); | |||
| const Int32ImmPtr &value = dyn_cast<Int32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT32); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<Int64Imm>()) { | |||
| const Int64ImmPtr& value = dyn_cast<Int64Imm>(val); | |||
| const Int64ImmPtr &value = dyn_cast<Int64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_INT64); | |||
| value_proto->set_int_val(value->value()); | |||
| } else if (val->isa<UInt8Imm>()) { | |||
| const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val); | |||
| const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT8); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt16Imm>()) { | |||
| const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val); | |||
| const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT16); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt32Imm>()) { | |||
| const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val); | |||
| const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT32); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<UInt64Imm>()) { | |||
| const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val); | |||
| const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_UINT64); | |||
| value_proto->set_uint_val(value->value()); | |||
| } else if (val->isa<FP32Imm>()) { | |||
| const FP32ImmPtr& value = dyn_cast<FP32Imm>(val); | |||
| const FP32ImmPtr &value = dyn_cast<FP32Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_FLOAT32); | |||
| value_proto->set_float_val(value->value()); | |||
| } else if (val->isa<FP64Imm>()) { | |||
| const FP64ImmPtr& value = dyn_cast<FP64Imm>(val); | |||
| const FP64ImmPtr &value = dyn_cast<FP64Imm>(val); | |||
| value_proto->set_dtype(irpb::DT_FLOAT64); | |||
| value_proto->set_double_val(value->value()); | |||
| } else { | |||
| @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val | |||
| } | |||
| } | |||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| if (val->isa<ValueTuple>()) { | |||
| const ValueTuplePtr& value = dyn_cast<ValueTuple>(val); | |||
| const ValueTuplePtr &value = dyn_cast<ValueTuple>(val); | |||
| value_proto->set_dtype(irpb::DT_TUPLE); | |||
| for (const auto& item : value->value()) { | |||
| for (const auto &item : value->value()) { | |||
| SetValueToProto(item, value_proto->add_values()); | |||
| } | |||
| } else if (val->isa<ValueList>()) { | |||
| const ValueListPtr& value = dyn_cast<ValueList>(val); | |||
| const ValueListPtr &value = dyn_cast<ValueList>(val); | |||
| value_proto->set_dtype(irpb::DT_LIST); | |||
| for (const auto& item : value->value()) { | |||
| for (const auto &item : value->value()) { | |||
| SetValueToProto(item, value_proto->add_values()); | |||
| } | |||
| } | |||
| } | |||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { | |||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { | |||
| if (val == nullptr || value_proto == nullptr) { | |||
| return; | |||
| } | |||
| value_proto->set_dtype(irpb::DT_DICT); | |||
| for (const auto& item : val->value()) { | |||
| irpb::NamedValueProto* named_val = value_proto->add_dict_val(); | |||
| for (const auto &item : val->value()) { | |||
| irpb::NamedValueProto *named_val = value_proto->add_dict_val(); | |||
| named_val->set_key(item.first); | |||
| SetValueToProto(item.second, named_val->mutable_value()); | |||
| } | |||
| } | |||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { | |||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { | |||
| if (node == nullptr || node_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& | |||
| MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); | |||
| } | |||
| const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node); | |||
| const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node); | |||
| node_proto->set_op_type(prim->name()); | |||
| for (const auto& attr : prim->attrs()) { | |||
| irpb::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| for (const auto &attr : prim->attrs()) { | |||
| irpb::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name(attr.first); | |||
| SetValueToProto(attr.second, attr_proto->mutable_value()); | |||
| } | |||
| node_proto->set_scope(node->scope()->name()); | |||
| } | |||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||
| if (node == nullptr || const_map_ptr == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt | |||
| MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; | |||
| } | |||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return ""; | |||
| } | |||
| InitModelInfo(); | |||
| irpb::GraphProto* graph_proto = model_.mutable_graph(); | |||
| irpb::GraphProto *graph_proto = model_.mutable_graph(); | |||
| ExportFuncGraph(func_graph, graph_proto); | |||
| return model_.SerializeAsString(); | |||
| } | |||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || graph_proto == nullptr) { | |||
| return; | |||
| } | |||
| @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP | |||
| ExportValueNodes(const_map, graph_proto); | |||
| } | |||
| void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || graph_proto == nullptr) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | |||
| for (auto& param : parameters) { | |||
| irpb::ParameterProto* param_proto = graph_proto->add_parameters(); | |||
| for (auto ¶m : parameters) { | |||
| irpb::ParameterProto *param_proto = graph_proto->add_parameters(); | |||
| param_proto->set_name(param->ToString()); | |||
| SetNodeOutputType(param, param_proto->mutable_type()); | |||
| @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph | |||
| } | |||
| } | |||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||
| if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { | |||
| return; | |||
| } | |||
| // topo sort nodes | |||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | |||
| std::map<AnfNodePtr, size_t> apply_map; | |||
| for (const AnfNodePtr& node : nodes) { | |||
| for (const AnfNodePtr &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt | |||
| } | |||
| } | |||
| void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||
| if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || | |||
| graph_proto == nullptr) { | |||
| return; | |||
| @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| auto apply_idx = apply_map_ptr->size() + 1; | |||
| (*apply_map_ptr)[node] = apply_idx; | |||
| auto& inputs = node->inputs(); | |||
| auto &inputs = node->inputs(); | |||
| if (inputs.size() < 1) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr op = inputs[0]; | |||
| irpb::NodeProto* node_proto = graph_proto->add_node(); | |||
| irpb::NodeProto *node_proto = graph_proto->add_node(); | |||
| // CNode/ConstGraph/Const/Parameter | |||
| if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) { | |||
| @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| // process OP inputs | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| irpb::InputProto* input_proto = node_proto->add_input(); | |||
| irpb::InputProto *input_proto = node_proto->add_input(); | |||
| input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); | |||
| std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); | |||
| input_proto->set_name(id); | |||
| @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||
| } | |||
| } | |||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||
| if (ret_node == nullptr || !ret_node->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Graph return node is illegal"; | |||
| } | |||
| @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||
| if (graph_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | |||
| } | |||
| irpb::OutputProto* output_proto = graph_proto->add_outputs(); | |||
| irpb::OutputProto *output_proto = graph_proto->add_outputs(); | |||
| if (output_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "output_proto is nullptr"; | |||
| } | |||
| @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||
| SetNodeOutputType(arg, output_proto->mutable_type()); | |||
| } | |||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) { | |||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) { | |||
| return x.second < y.second; | |||
| } | |||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) { | |||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) { | |||
| std::vector<std::pair<AnfNodePtr, size_t>> nodes; | |||
| (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), | |||
| [](const std::pair<AnfNodePtr, size_t>& item) { return item; }); | |||
| [](const std::pair<AnfNodePtr, size_t> &item) { return item; }); | |||
| sort(nodes.begin(), nodes.end(), CompareValue); | |||
| for (auto& item : nodes) { | |||
| for (auto &item : nodes) { | |||
| if (graph_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | |||
| } | |||
| irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); | |||
| irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); | |||
| MS_EXCEPTION_IF_NULL(named_value); | |||
| named_value->set_key(GetConstNodeId(item.second)); | |||
| SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); | |||
| @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m | |||
| void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||
| ProtoExporter exporter; | |||
| return exporter.GetFuncGraphProtoString(func_graph); | |||
| } | |||
| @@ -36,7 +36,7 @@ Dump::Dump() | |||
| dump_iter_(0), | |||
| cur_iter_(0) {} | |||
| bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||
| bool Dump::IsKernelNeedDump(const std::string &kernel_name) { | |||
| if (dump_mode_ == 0) { | |||
| // Dump All Kernels mode | |||
| return true; | |||
| @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||
| return false; | |||
| } | |||
| bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||
| bool Dump::ParseDumpConfig(const std::string &dump_config_file) { | |||
| std::ifstream jsonFile(dump_config_file); | |||
| if (!jsonFile.is_open()) { | |||
| MS_LOG(ERROR) << dump_config_file << " open failed."; | |||
| @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||
| return true; | |||
| } | |||
| bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||
| bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { | |||
| if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || | |||
| dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || | |||
| dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || | |||
| @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||
| return true; | |||
| } | |||
| bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||
| bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { | |||
| auto trans_flag = dumpSettings.at("trans_flag"); | |||
| auto enable = dumpSettings.at("enable"); | |||
| auto mode = dumpSettings.at("mode"); | |||
| @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||
| dump_path_ = path; | |||
| dump_net_name_ = net_name; | |||
| dump_iter_ = iteration; | |||
| for (const auto& kernel : kernels) { | |||
| for (const auto &kernel : kernels) { | |||
| dump_kernels_.push_back(kernel); | |||
| } | |||
| return true; | |||
| } | |||
| bool Dump::SetDumpConfFromJsonFile() { | |||
| const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||
| const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||
| if (config_path_str != nullptr) { | |||
| MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; | |||
| } else { | |||
| @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { | |||
| return ParseDumpConfig(dump_config_file); | |||
| } | |||
| bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { | |||
| bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { | |||
| if (filename.empty() || data == nullptr || len == 0) { | |||
| MS_LOG(ERROR) << "Incorrect parameter."; | |||
| return false; | |||
| @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) | |||
| MS_LOG(ERROR) << "Open file " << realpath << " fail."; | |||
| return false; | |||
| } | |||
| (void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len)); | |||
| (void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len)); | |||
| fd.close(); | |||
| return true; | |||
| } | |||
| bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||
| bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { | |||
| MS_EXCEPTION_IF_NULL(outpath); | |||
| auto path_split_pos = inpath.find_last_of('/'); | |||
| if (path_split_pos == std::string::npos) { | |||
| @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||
| return true; | |||
| } | |||
| bool Dump::CreateNotExistDirs(const std::string& path) { | |||
| bool Dump::CreateNotExistDirs(const std::string &path) { | |||
| std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); | |||
| MS_EXCEPTION_IF_NULL(fs); | |||
| char temp_path[PATH_MAX] = {0}; | |||
| @@ -43,11 +43,11 @@ class Dump { | |||
| uint32_t cur_iter() const { return cur_iter_; } | |||
| bool IsKernelNeedDump(const std::string& kernel_name); | |||
| bool IsKernelNeedDump(const std::string &kernel_name); | |||
| bool SetDumpConfFromJsonFile(); | |||
| static bool DumpToFile(const std::string& filename, const void* data, size_t len); | |||
| static bool DumpToFile(const std::string &filename, const void *data, size_t len); | |||
| protected: | |||
| bool dump_enable_; | |||
| @@ -59,14 +59,14 @@ class Dump { | |||
| uint32_t cur_iter_; | |||
| std::vector<std::string> dump_kernels_; | |||
| static bool GetRealPath(const std::string& inpath, std::string* outpath); | |||
| static bool GetRealPath(const std::string &inpath, std::string *outpath); | |||
| static bool CreateNotExistDirs(const std::string& path); | |||
| static bool CreateNotExistDirs(const std::string &path); | |||
| private: | |||
| bool ParseDumpConfig(const std::string& dump_config_file); | |||
| bool IsConfigExist(const nlohmann::json& dumpSettings); | |||
| bool IsConfigValid(const nlohmann::json& dumpSettings); | |||
| bool ParseDumpConfig(const std::string &dump_config_file); | |||
| bool IsConfigExist(const nlohmann::json &dumpSettings); | |||
| bool IsConfigValid(const nlohmann::json &dumpSettings); | |||
| }; | |||
| using DumpConfPtr = std::shared_ptr<Dump>; | |||
| @@ -23,7 +23,7 @@ | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { | |||
| std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { | |||
| std::string temp_line = line; | |||
| if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && | |||
| tip != kSourceLineTipDiscard) { | |||
| @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { | |||
| name_ = ""; | |||
| } | |||
| DebugInfo::DebugInfo(const std::string& name) { | |||
| DebugInfo::DebugInfo(const std::string &name) { | |||
| InitValueFromContext(); | |||
| unique_id_ = gen_unique_id(); | |||
| debug_id_ = -1; | |||
| name_ = name; | |||
| } | |||
| DebugInfo::DebugInfo(const LocationPtr& loc) { | |||
| DebugInfo::DebugInfo(const LocationPtr &loc) { | |||
| InitValueFromContext(); | |||
| unique_id_ = gen_unique_id(); | |||
| debug_id_ = -1; | |||
| @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { | |||
| } | |||
| int64_t DebugInfo::unique_id_through_copy() const { | |||
| TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info(); | |||
| TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info(); | |||
| if (trace_info != nullptr) { | |||
| if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { | |||
| return trace_info->debug_info()->unique_id_through_copy(); | |||
| @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { | |||
| } | |||
| return DebugInfo::location(); | |||
| } | |||
| void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } | |||
| void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } | |||
| TraceContextPtr TraceManager::CurrentContextInfo() { | |||
| if (!TraceManager::trace_context_stack_.empty()) { | |||
| @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { | |||
| return nullptr; | |||
| } | |||
| void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { | |||
| void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { | |||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | |||
| context->set_func_name(func_name); | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const LocationPtr& location) { | |||
| void TraceManager::DebugTrace(const LocationPtr &location) { | |||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||
| void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { | |||
| if (trace_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | |||
| } | |||
| @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||
| TraceManager::trace_context_stack_.push(context); | |||
| } | |||
| void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { | |||
| void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { | |||
| if (trace_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | |||
| } | |||
| @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou | |||
| // Location class record the location in source code. | |||
| class Location { | |||
| public: | |||
| Location(const std::string& file_name, int line, int column, int line_end, int column_end) | |||
| Location(const std::string &file_name, int line, int column, int line_end, int column_end) | |||
| : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} | |||
| Location(const Location& loc) | |||
| Location(const Location &loc) | |||
| : file_name_(loc.file_name_), | |||
| line_(loc.line_), | |||
| column_(loc.column_), | |||
| @@ -77,21 +77,21 @@ class TraceManager { | |||
| TraceManager() = default; | |||
| ~TraceManager() = default; | |||
| static TraceContextPtr CurrentContextInfo(); | |||
| static void DebugTrace(const std::string& func_name, const LocationPtr& location); | |||
| static void DebugTrace(const LocationPtr& location); | |||
| static void DebugTrace(const TraceInfoPtr& trace_info); | |||
| static void DebugTrace(const std::string &func_name, const LocationPtr &location); | |||
| static void DebugTrace(const LocationPtr &location); | |||
| static void DebugTrace(const TraceInfoPtr &trace_info); | |||
| // debug trace with a cloned trace info with debug_info | |||
| static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); | |||
| static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); | |||
| static void EndTrace(); | |||
| static std::stack<TraceContextPtr> trace_context_stack_; | |||
| }; | |||
| class TraceGuard { | |||
| public: | |||
| explicit TraceGuard(const std::string func_name, const LocationPtr& location) { | |||
| explicit TraceGuard(const std::string func_name, const LocationPtr &location) { | |||
| TraceManager::DebugTrace(func_name, location); | |||
| } | |||
| explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } | |||
| explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } | |||
| ~TraceGuard() { TraceManager::EndTrace(); } | |||
| }; | |||
| @@ -106,23 +106,23 @@ class TraceContext { | |||
| public: | |||
| ~TraceContext() = default; | |||
| explicit TraceContext(const LocationPtr& loc) { | |||
| explicit TraceContext(const LocationPtr &loc) { | |||
| ProcessAttributeFromContext(); | |||
| location_ = loc; | |||
| } | |||
| explicit TraceContext(const std::string& func_name) { | |||
| explicit TraceContext(const std::string &func_name) { | |||
| ProcessAttributeFromContext(); | |||
| func_name_ = func_name; | |||
| } | |||
| explicit TraceContext(const TraceInfoPtr& trace_info) { | |||
| explicit TraceContext(const TraceInfoPtr &trace_info) { | |||
| ProcessAttributeFromContext(); | |||
| trace_info_ = trace_info; | |||
| } | |||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||
| LocationPtr location() { return location_; } | |||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||
| TraceInfoPtr trace_info() { return trace_info_; } | |||
| void set_func_name(const std::string& func_name) { func_name_ = func_name; } | |||
| void set_func_name(const std::string &func_name) { func_name_ = func_name; } | |||
| std::string func_name() { return func_name_; } | |||
| }; | |||
| @@ -130,9 +130,9 @@ class DebugInfo : public Base { | |||
| public: | |||
| DebugInfo(); | |||
| explicit DebugInfo(const std::string& name); | |||
| explicit DebugInfo(const std::string &name); | |||
| explicit DebugInfo(const LocationPtr& loc); | |||
| explicit DebugInfo(const LocationPtr &loc); | |||
| virtual ~DebugInfo() = default; | |||
| MS_DECLARE_PARENT(DebugInfo, Base); | |||
| @@ -141,12 +141,12 @@ class DebugInfo : public Base { | |||
| int64_t unique_id_through_copy() const; | |||
| std::string get_id() { return std::to_string(debug_id()); } | |||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||
| TraceInfoPtr trace_info() { return trace_info_; } | |||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||
| virtual LocationPtr location() { return location_; } | |||
| std::string name() { return name_; } | |||
| void set_name(const std::string& name) { name_ = name; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| virtual std::string debug_name(); | |||
| virtual std::string get_python_func_belonged() { return ""; } | |||
| @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { | |||
| py_func_belonged_ = context_info->func_name(); | |||
| } | |||
| } | |||
| explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { | |||
| explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { | |||
| if (TraceManager::CurrentContextInfo() != nullptr) { | |||
| auto context_info = TraceManager::CurrentContextInfo(); | |||
| py_func_belonged_ = context_info->func_name(); | |||
| @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { | |||
| ~NodeDebugInfo() override = default; | |||
| std::string debug_name() override; | |||
| void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); } | |||
| void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); } | |||
| std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } | |||
| void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } | |||
| void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } | |||
| std::string get_python_func_belonged() override { return py_func_belonged_; } | |||
| AnfNodeWeakPtr node_; | |||
| std::string py_func_belonged_; | |||
| @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { | |||
| } | |||
| } | |||
| explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { | |||
| explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { | |||
| if (TraceManager::CurrentContextInfo() != nullptr) { | |||
| auto context_info = TraceManager::CurrentContextInfo(); | |||
| py_func_name_ = context_info->func_name(); | |||
| @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { | |||
| std::string debug_name() override; | |||
| LocationPtr location() override; | |||
| LocationPtr deco_location() { return deco_loc_; } | |||
| void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||
| void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||
| FuncGraphPtr get_graph() const { return func_graph_.lock(); } | |||
| void set_full_name(const std::string& name) { full_name_ = name; } | |||
| void set_full_name(const std::string &name) { full_name_ = name; } | |||
| std::string get_full_name() { return full_name_; } | |||
| void set_deco_location(const LocationPtr& deco_list_loc); | |||
| void set_deco_location(const LocationPtr &deco_list_loc); | |||
| std::string get_python_func_belonged() override { return py_func_name_; } | |||
| FuncGraphWeakPtr func_graph_; | |||
| LocationPtr deco_loc_; | |||
| @@ -31,7 +31,7 @@ struct NameWithTrace { | |||
| std::string name; | |||
| std::vector<std::string> trace_labels; | |||
| }; | |||
| static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { | |||
| static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { | |||
| switch (trace_label) { | |||
| case TraceLabelType::kShortSymbol: | |||
| return trace_info->symbol(); | |||
| @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t | |||
| } | |||
| } | |||
| NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace trace_name; | |||
| // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node | |||
| auto temp_info = debug_info; | |||
| @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe | |||
| return trace_name; | |||
| } | |||
| std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) { | |||
| std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) { | |||
| std::string tags = ""; | |||
| for (auto& itr : trace_labels) { | |||
| for (auto &itr : trace_labels) { | |||
| std::string symbol = itr; | |||
| tags = tags + symbol; | |||
| } | |||
| @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st | |||
| } | |||
| // get the label name of the node debug info | |||
| std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| NameWithTrace trace_name = RootName(debug_info, trace_label); | |||
| return CombineTraceTypes(trace_name.name, trace_name.trace_labels); | |||
| } | |||
| std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||
| std::string CombineUniqueID(const DebugInfoPtr &debug_info) { | |||
| auto temp_info = debug_info; | |||
| std::string label = ""; | |||
| while (temp_info != nullptr) { | |||
| @@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||
| } | |||
| // get trace with unique id chain | |||
| std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); } | |||
| std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); } | |||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||
| if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { | |||
| return LabelStringUnique(debug_info); | |||
| } | |||
| @@ -29,7 +29,7 @@ namespace label_manage { | |||
| enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; | |||
| TraceLabelType GetGlobalTraceLabelType(); | |||
| void SetGlobalTraceLabelType(TraceLabelType label_type); | |||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||
| } // namespace label_manage | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,7 @@ | |||
| namespace mindspore { | |||
| // namespace to support debug trace infomation | |||
| namespace trace { | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) { | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) { | |||
| if (abs == nullptr) { | |||
| return "Null Abstract"; | |||
| } | |||
| @@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { | |||
| return debug_with_loc_vec; | |||
| } | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { | |||
| auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); | |||
| if (debug_with_loc_vec.size() > 0) { | |||
| return debug_with_loc_vec[0]; | |||
| @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||
| } | |||
| } | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| // a trace info identifies a node transform, so we can trace the node transform through | |||
| // a link of trace info and debug info | |||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) { | |||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) { | |||
| if (info_vec.size() < 1) { | |||
| return ""; | |||
| } | |||
| @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL | |||
| return traced_info; | |||
| } | |||
| std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||
| return ""; | |||
| } | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { | |||
| std::ostringstream oss; | |||
| if (info == nullptr) { | |||
| return ""; | |||
| @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So | |||
| return oss.str(); | |||
| } | |||
| std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { | |||
| std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { | |||
| std::ostringstream oss; | |||
| oss << "graph:" << graph->ToString() << " with args["; | |||
| auto params = graph->parameters(); | |||
| @@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas | |||
| return oss.str(); | |||
| } | |||
| void DumpInferStack(std::ostringstream& oss) { | |||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||
| void DumpInferStack(std::ostringstream &oss) { | |||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||
| if (infer_stack.empty()) { | |||
| return; | |||
| } | |||
| @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||
| } | |||
| std::reverse(infer_vec.begin(), infer_vec.end()); | |||
| int index = 0; | |||
| for (auto& item : infer_vec) { | |||
| for (auto &item : infer_vec) { | |||
| auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); | |||
| if (graph_infer == nullptr) { | |||
| MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; | |||
| @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||
| } | |||
| void TraceGraphInfer() { | |||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||
| std::ostringstream oss; | |||
| if (infer_stack.empty()) { | |||
| return; | |||
| @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { | |||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | |||
| ~AnalyzedFuncGraphExporter() override = default; | |||
| void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); | |||
| void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs); | |||
| private: | |||
| std::string GetNodeType(const AnfNodePtr& nd) override; | |||
| std::string GetNodeType(const AnfNodePtr &nd) override; | |||
| }; | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | |||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | |||
| auto& list = GetCNodeDebugStack(); | |||
| auto &list = GetCNodeDebugStack(); | |||
| for (size_t i = 0; i < list.size(); ++i) { | |||
| auto node_cfg = list[i]; | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { | |||
| exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | |||
| } | |||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||
| if (node_cfg_ == nullptr) { | |||
| return AnfExporter::GetNodeType(node); | |||
| } | |||
| @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||
| return oss.str(); | |||
| } | |||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { | |||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, | |||
| const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) { | |||
| if (node_cfgs.empty()) { | |||
| MS_LOG(DEBUG) << "Node configs is empty"; | |||
| return; | |||
| @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| auto tagged_func_graphs = CalcTaggedFuncGraphs(); | |||
| // first output graph on the analysis stack | |||
| for (const auto& node_cfg : node_cfgs) { | |||
| for (const auto &node_cfg : node_cfgs) { | |||
| auto fg = node_cfg->context()->func_graph(); | |||
| // the graph is already output, skip it | |||
| if (exported.find(fg) != exported.end()) { | |||
| @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||
| ofs.close(); | |||
| } | |||
| void GetInferStackInfo(std::ostringstream& oss) { | |||
| void GetInferStackInfo(std::ostringstream &oss) { | |||
| MS_LOG(INFO) << "Get graph analysis information begin"; | |||
| auto stack = GetCNodeDebugStack(); | |||
| if (stack.empty()) { | |||
| @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { | |||
| static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | |||
| // trace the cnode infer debug info | |||
| static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { | |||
| if (eval == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | |||
| } | |||
| @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An | |||
| } | |||
| } | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { | |||
| if (eval == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | |||
| } | |||
| @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||
| } | |||
| } | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||
| void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } | |||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; } | |||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() { | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() { | |||
| return graph_infer_stack; | |||
| } | |||
| void ClearTraceStack() { | |||
| @@ -31,19 +31,19 @@ | |||
| namespace mindspore { | |||
| namespace trace { | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); | |||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); | |||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, | |||
| SourceLineTip tip = kSourceLineTipNextLine); | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); | |||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); | |||
| void TraceGraphInfer(); | |||
| void GetInferStackInfo(std::ostringstream& oss); | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); | |||
| void GetInferStackInfo(std::ostringstream &oss); | |||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); | |||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); | |||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); | |||
| void TraceInferCNodeLeave(); | |||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack(); | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack(); | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); | |||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); | |||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); | |||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); | |||
| void ClearTraceStack(); | |||
| } // namespace trace | |||
| } // namespace mindspore | |||
| @@ -23,7 +23,7 @@ | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { | |||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { | |||
| if (info == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>; | |||
| // namespace to support intermediate representation definition | |||
| class TraceInfo : public Base { | |||
| public: | |||
| TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { | |||
| TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { | |||
| symbol_ = symbol; | |||
| full_name_ = full_name; | |||
| name_ = full_name_; | |||
| debug_info_ = info; | |||
| } | |||
| TraceInfo(const TraceInfo& info) | |||
| TraceInfo(const TraceInfo &info) | |||
| : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} | |||
| virtual ~TraceInfo() = default; | |||
| MS_DECLARE_PARENT(TraceInfo, Base); | |||
| @@ -55,8 +55,8 @@ class TraceInfo : public Base { | |||
| virtual std::string full_name() { return full_name_; } | |||
| virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); } | |||
| virtual std::string action_name() { return ""; } | |||
| virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); | |||
| void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } | |||
| virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); | |||
| void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } | |||
| DebugInfoPtr debug_info() { return debug_info_; } | |||
| DebugInfoPtr DebugInfoHasLoc(); | |||
| std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo(); | |||
| @@ -70,7 +70,7 @@ class TraceInfo : public Base { | |||
| class TracePhi : public TraceInfo { | |||
| public: | |||
| explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} | |||
| explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} | |||
| MS_DECLARE_PARENT(TracePhi, TraceInfo); | |||
| ~TracePhi() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); } | |||
| @@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { | |||
| class TraceIfStmtTrueBranch : public TraceInfo { | |||
| public: | |||
| TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; | |||
| explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} | |||
| TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; | |||
| explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} | |||
| MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); | |||
| ~TraceIfStmtTrueBranch() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { | |||
| class TraceIfStmtFalseBranch : public TraceInfo { | |||
| public: | |||
| TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; | |||
| explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} | |||
| TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; | |||
| explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} | |||
| MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); | |||
| ~TraceIfStmtFalseBranch() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { | |||
| class TraceIfStmtAfterBranch : public TraceInfo { | |||
| public: | |||
| explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} | |||
| explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} | |||
| MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); | |||
| ~TraceIfStmtAfterBranch() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { | |||
| class TraceIfExpTrueBranch : public TraceInfo { | |||
| public: | |||
| explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} | |||
| explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} | |||
| MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); | |||
| ~TraceIfExpTrueBranch() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { | |||
| class TraceIfExpFalseBranch : public TraceInfo { | |||
| public: | |||
| explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} | |||
| explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} | |||
| MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); | |||
| ~TraceIfExpFalseBranch() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { | |||
| class TraceCopy : public TraceInfo { | |||
| public: | |||
| TraceCopy() : TraceInfo(nullptr, "copy", "") {} | |||
| explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} | |||
| explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} | |||
| MS_DECLARE_PARENT(TraceCopy, TraceInfo); | |||
| ~TraceCopy() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); } | |||
| @@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { | |||
| class TraceIterator : public TraceInfo { | |||
| public: | |||
| explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} | |||
| explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} | |||
| MS_DECLARE_PARENT(TraceIterator, TraceInfo); | |||
| ~TraceIterator() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); } | |||
| @@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { | |||
| class TraceWhileHeader : public TraceInfo { | |||
| public: | |||
| explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} | |||
| explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} | |||
| MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); | |||
| ~TraceWhileHeader() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); } | |||
| @@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { | |||
| class TraceWhileBody : public TraceInfo { | |||
| public: | |||
| explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} | |||
| explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} | |||
| MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); | |||
| ~TraceWhileBody() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); } | |||
| @@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { | |||
| class TraceWhileAfter : public TraceInfo { | |||
| public: | |||
| explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} | |||
| explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} | |||
| MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); | |||
| ~TraceWhileAfter() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); } | |||
| @@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { | |||
| class TraceForHeader : public TraceInfo { | |||
| public: | |||
| explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} | |||
| explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} | |||
| MS_DECLARE_PARENT(TraceForHeader, TraceInfo); | |||
| ~TraceForHeader() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); } | |||
| @@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { | |||
| class TraceForBody : public TraceInfo { | |||
| public: | |||
| explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} | |||
| explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} | |||
| MS_DECLARE_PARENT(TraceForBody, TraceInfo); | |||
| ~TraceForBody() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); } | |||
| @@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { | |||
| class TraceForAfter : public TraceInfo { | |||
| public: | |||
| explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} | |||
| explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} | |||
| MS_DECLARE_PARENT(TraceForAfter, TraceInfo); | |||
| ~TraceForAfter() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } | |||
| @@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { | |||
| class TraceEquiv : public TraceInfo { | |||
| public: | |||
| explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} | |||
| explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} | |||
| MS_DECLARE_PARENT(TraceEquiv, TraceInfo); | |||
| ~TraceEquiv() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); } | |||
| @@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { | |||
| class TraceGradFpropApp : public TraceInfo { | |||
| public: | |||
| TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} | |||
| explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} | |||
| explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} | |||
| MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); | |||
| ~TraceGradFpropApp() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); } | |||
| @@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { | |||
| class TraceGradBpropApp : public TraceInfo { | |||
| public: | |||
| TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} | |||
| explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} | |||
| explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} | |||
| MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); | |||
| ~TraceGradBpropApp() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); } | |||
| @@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { | |||
| class TraceGradFprop : public TraceInfo { | |||
| public: | |||
| TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} | |||
| explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} | |||
| explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} | |||
| MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); | |||
| ~TraceGradFprop() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); } | |||
| @@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { | |||
| class TraceGradBprop : public TraceInfo { | |||
| public: | |||
| TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} | |||
| explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} | |||
| explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} | |||
| MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); | |||
| ~TraceGradBprop() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); } | |||
| @@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { | |||
| class TraceGradSens : public TraceInfo { | |||
| public: | |||
| TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} | |||
| explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} | |||
| explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} | |||
| MS_DECLARE_PARENT(TraceGradSens, TraceInfo); | |||
| ~TraceGradSens() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); } | |||
| @@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { | |||
| class TraceSpecialize : public TraceInfo { | |||
| public: | |||
| explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } | |||
| explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } | |||
| MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); | |||
| std::string name() override { return full_name_ + counter_; } | |||
| std::string symbol() override { return counter_ + "_"; } | |||
| @@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { | |||
| class TraceGradOperation : public TraceInfo { | |||
| public: | |||
| explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} | |||
| explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} | |||
| MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); | |||
| ~TraceGradOperation() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { | |||
| class TraceForceBool : public TraceInfo { | |||
| public: | |||
| explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} | |||
| explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} | |||
| MS_DECLARE_PARENT(TraceForceBool, TraceInfo); | |||
| ~TraceForceBool() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); } | |||
| @@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { | |||
| class TraceExpandJ : public TraceInfo { | |||
| public: | |||
| explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} | |||
| explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} | |||
| MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); | |||
| ~TraceExpandJ() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); } | |||
| @@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { | |||
| class TraceGenMetaFuncGraph : public TraceInfo { | |||
| public: | |||
| explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} | |||
| explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} | |||
| MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); | |||
| ~TraceGenMetaFuncGraph() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { | |||
| class TraceEvaluatorGenGraph : public TraceInfo { | |||
| public: | |||
| explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} | |||
| explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} | |||
| MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); | |||
| ~TraceEvaluatorGenGraph() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { | |||
| class TraceResolve : public TraceInfo { | |||
| public: | |||
| explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} | |||
| explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} | |||
| MS_DECLARE_PARENT(TraceResolve, TraceInfo); | |||
| ~TraceResolve() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); } | |||
| @@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { | |||
| class TraceTransform : public TraceInfo { | |||
| public: | |||
| TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } | |||
| explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { | |||
| explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { | |||
| transform_name_ = transform_name; | |||
| } | |||
| @@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { | |||
| class TraceGenerateVarArg : public TraceInfo { | |||
| public: | |||
| explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} | |||
| explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} | |||
| MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); | |||
| ~TraceGenerateVarArg() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { | |||
| class TraceGenerateKwArg : public TraceInfo { | |||
| public: | |||
| explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} | |||
| explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} | |||
| MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); | |||
| ~TraceGenerateKwArg() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { | |||
| class TraceTrasformK : public TraceInfo { | |||
| public: | |||
| explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} | |||
| explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} | |||
| MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); | |||
| ~TraceTrasformK() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); } | |||
| @@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { | |||
| class TracePartialTransform : public TraceInfo { | |||
| public: | |||
| explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} | |||
| explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} | |||
| MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); | |||
| ~TracePartialTransform() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { | |||
| class TraceGetEnv : public TraceInfo { | |||
| public: | |||
| explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} | |||
| explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} | |||
| MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); | |||
| ~TraceGetEnv() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); } | |||
| @@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { | |||
| class TraceDoSignature : public TraceInfo { | |||
| public: | |||
| explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} | |||
| explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} | |||
| MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); | |||
| ~TraceDoSignature() override = default; | |||
| TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); } | |||
| @@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { | |||
| class TraceCombileLikeGraphs : public TraceInfo { | |||
| public: | |||
| TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} | |||
| explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} | |||
| explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} | |||
| MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); | |||
| ~TraceCombileLikeGraphs() override = default; | |||
| TraceInfoPtr clone() override { | |||
| @@ -21,7 +21,7 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | |||
| if (has_malloc_) { | |||
| MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; | |||
| } | |||
| @@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||
| return size; | |||
| } | |||
| bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { | |||
| bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| has_malloc_ = false; | |||
| free_mem_size_ = total_mem_size_; | |||
| @@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { | |||
| size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } | |||
| void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { | |||
| void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { | |||
| MS_EXCEPTION_IF_NULL(device_mem_pool_base); | |||
| device_mem_pool_base_ = device_mem_pool_base; | |||
| } | |||
| @@ -26,12 +26,12 @@ namespace ascend { | |||
| class AscendMemoryPool : public DynamicMemPoolBestFit { | |||
| public: | |||
| ~AscendMemoryPool() override = default; | |||
| AscendMemoryPool(const AscendMemoryPool&) = delete; | |||
| AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; | |||
| AscendMemoryPool(const AscendMemoryPool &) = delete; | |||
| AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; | |||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; | |||
| bool FreeDeviceMem(const DeviceMemPtr& addr) override; | |||
| void set_device_mem_pool_base(uint8_t* device_mem_pool_base); | |||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | |||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | |||
| void set_device_mem_pool_base(uint8_t *device_mem_pool_base); | |||
| void set_device_mem_pool_size(uint64_t device_mem_pool_size) { | |||
| device_mem_pool_size_ = device_mem_pool_size; | |||
| free_mem_size_ = device_mem_pool_size_; | |||
| @@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||
| size_t free_mem_size() override; | |||
| size_t total_mem_size() override; | |||
| static AscendMemoryPool& GetInstance() { | |||
| static AscendMemoryPool &GetInstance() { | |||
| static AscendMemoryPool instance; | |||
| return instance; | |||
| } | |||
| @@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||
| private: | |||
| AscendMemoryPool() = default; | |||
| bool has_malloc_{false}; | |||
| uint8_t* device_mem_pool_base_{nullptr}; | |||
| uint8_t *device_mem_pool_base_{nullptr}; | |||
| uint64_t device_mem_pool_size_{0}; | |||
| size_t free_mem_size_{0}; | |||
| size_t total_mem_size_{0}; | |||
| @@ -39,13 +39,13 @@ using std::vector; | |||
| class AscendStreamAssign { | |||
| public: | |||
| static AscendStreamAssign& GetInstance() { | |||
| static AscendStreamAssign &GetInstance() { | |||
| static AscendStreamAssign instance; // Guaranteed to be destroyed. | |||
| return instance; | |||
| } | |||
| AscendStreamAssign(const AscendStreamAssign&) = delete; | |||
| AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; | |||
| AscendStreamAssign(const AscendStreamAssign &) = delete; | |||
| AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; | |||
| uint32_t GetTotalStreamNum() const; | |||
| // new stream policy | |||
| @@ -53,19 +53,19 @@ class AscendStreamAssign { | |||
| uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } | |||
| uint32_t total_event_num() const { return total_event_num_; } | |||
| void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void ResetNew(); | |||
| void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| bool IsIndependentNode(const CNodePtr& node_ptr); | |||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; } | |||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; } | |||
| const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; } | |||
| void GetWaitStreams(vector<uint32_t>* wait_active_stream_list); | |||
| const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; } | |||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | |||
| void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| bool IsIndependentNode(const CNodePtr &node_ptr); | |||
| const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; } | |||
| const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; } | |||
| const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; } | |||
| void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | |||
| const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; } | |||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||
| uint32_t stream_id); | |||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | |||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||
| uint32_t stream_id); | |||
| private: | |||
| @@ -73,30 +73,30 @@ class AscendStreamAssign { | |||
| ~AscendStreamAssign() = default; | |||
| vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, | |||
| const CNodePtr& node); | |||
| const CNodePtr &node); | |||
| bool IsHcom(const CNodePtr& apply_kernel); | |||
| bool IsHcom(const CNodePtr &apply_kernel); | |||
| bool IsProcessed(uint32_t logic_id); | |||
| void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids); | |||
| void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, | |||
| uint32_t* cur_stream_id); | |||
| void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids); | |||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, | |||
| uint32_t *cur_stream_id); | |||
| void RecordIdMap(uint32_t logic_id, uint32_t physic_id); | |||
| void UpdateStreamActive(const CNodePtr& active_ptr); | |||
| void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); | |||
| void UpdateStreamActive(const CNodePtr &active_ptr); | |||
| void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); | |||
| bool IsTaskSink(); | |||
| void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); | |||
| void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | |||
| uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); | |||
| void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); | |||
| void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | |||
| uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); | |||
| void SetCommonStreamNum(uint32_t cur_stream_id); | |||
| void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| bool IsProcessedParallelStream(uint32_t stream_id); | |||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams); | |||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); | |||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| uint32_t total_common_stream_num_{0}; | |||
| uint32_t total_independ_stream_num_{0}; | |||
| @@ -28,14 +28,14 @@ namespace device { | |||
| namespace ascend { | |||
| class PluginImpl : public PluginIntf { | |||
| public: | |||
| explicit PluginImpl(const std::string& module); | |||
| explicit PluginImpl(const std::string &module); | |||
| ~PluginImpl() override = default; | |||
| int Init(const Reporter* reporter) override; | |||
| int Init(const Reporter *reporter) override; | |||
| int UnInit() override; | |||
| static Reporter* GetPluginReporter() { return reporter_; } | |||
| static Reporter *GetPluginReporter() { return reporter_; } | |||
| private: | |||
| static Reporter* reporter_; | |||
| static Reporter *reporter_; | |||
| std::string module_; | |||
| }; | |||
| } // namespace ascend | |||
| @@ -20,12 +20,12 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| PluginIntf* ProfilingEngineImpl::CreatePlugin() { | |||
| PluginIntf *ProfilingEngineImpl::CreatePlugin() { | |||
| MS_LOG(INFO) << "Create Plugin."; | |||
| return new (std::nothrow) PluginImpl("Framework"); | |||
| } | |||
| int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { | |||
| int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { | |||
| if (plugin != nullptr) { | |||
| delete plugin; | |||
| } | |||
| @@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { | |||
| ProfilingEngineImpl() = default; | |||
| ~ProfilingEngineImpl() override = default; | |||
| PluginIntf* CreatePlugin() override; | |||
| int ReleasePlugin(PluginIntf* plugin) override; | |||
| PluginIntf *CreatePlugin() override; | |||
| int ReleasePlugin(PluginIntf *plugin) override; | |||
| }; | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -35,7 +35,7 @@ using Json = nlohmann::json; | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| ProfilingManager& ProfilingManager::GetInstance() { | |||
| ProfilingManager &ProfilingManager::GetInstance() { | |||
| static ProfilingManager inst; | |||
| return inst; | |||
| } | |||
| @@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { | |||
| } | |||
| uint64_t ProfilingManager::GetJobId() const { | |||
| const char* job_id = std::getenv("JOB_ID"); | |||
| const char *job_id = std::getenv("JOB_ID"); | |||
| return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); | |||
| } | |||
| bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const { | |||
| bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const { | |||
| if (!IsProfiling()) { | |||
| MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | |||
| return false; | |||
| @@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI | |||
| MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); | |||
| Msprof::Engine::ReporterData reporter_data = {}; | |||
| for (const auto& iter : op_taskId_map) { | |||
| for (const auto &iter : op_taskId_map) { | |||
| auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; | |||
| reporter_data.deviceId = UintToInt(device_id_); | |||
| reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str())); | |||
| reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str())); | |||
| reporter_data.dataLen = data.size(); | |||
| auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); | |||
| if (ret != 0) { | |||
| @@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI | |||
| return true; | |||
| } | |||
| static std::vector<std::string> Split(const std::string& str, const char delim) { | |||
| static std::vector<std::string> Split(const std::string &str, const char delim) { | |||
| std::vector<std::string> elems; | |||
| if (str.empty()) { | |||
| @@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { | |||
| device_id_ = device_id; | |||
| // exp: export PROFILING_MODE=true | |||
| // export PROFILING_OPTIONS=training_trace | |||
| const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); | |||
| const char *prof_options_str = std::getenv("PROFILING_OPTIONS"); | |||
| // register Framework to profiling | |||
| int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); | |||
| if (result != 0) { | |||
| @@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { | |||
| MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | |||
| return true; | |||
| } | |||
| Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); | |||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||
| if (reporter != nullptr) { | |||
| MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); | |||
| } | |||
| @@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, | |||
| class GpuQueue { | |||
| public: | |||
| GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); | |||
| GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity); | |||
| virtual ~GpuQueue(); | |||
| void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; } | |||
| void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; } | |||
| inline bool IsEmpty() const { return head_ == tail_; } | |||
| inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } | |||
| BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); | |||
| BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; | |||
| BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size); | |||
| BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const; | |||
| BlockQueueStatus_T Pop(); | |||
| bool Destroy(); | |||
| private: | |||
| struct NodeInfo { | |||
| std::unique_ptr<cudaEvent_t> event_; | |||
| void* host_feature_addr_; | |||
| void* host_label_addr_; | |||
| void *host_feature_addr_; | |||
| void *host_label_addr_; | |||
| }; | |||
| void* buffer_; | |||
| void *buffer_; | |||
| size_t head_; | |||
| size_t tail_; | |||
| size_t feature_size_; | |||
| @@ -61,10 +61,10 @@ class GpuQueue { | |||
| size_t capacity_; | |||
| cudaStream_t stream_; | |||
| std::unique_ptr<NodeInfo[]> node_info_; | |||
| std::function<void(void*)> host_release_; | |||
| std::function<void(void *)> host_release_; | |||
| GpuQueue(const GpuQueue&) = delete; | |||
| GpuQueue& operator=(const GpuQueue&) = delete; | |||
| GpuQueue(const GpuQueue &) = delete; | |||
| GpuQueue &operator=(const GpuQueue &) = delete; | |||
| }; | |||
| class BlockingQueue { | |||
| @@ -72,11 +72,11 @@ class BlockingQueue { | |||
| BlockingQueue() : queue_(nullptr) {} | |||
| ~BlockingQueue() = default; | |||
| BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); | |||
| void RegisterRelease(const std::function<void(void*)>& func); | |||
| BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, | |||
| BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity); | |||
| void RegisterRelease(const std::function<void(void *)> &func); | |||
| BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size, | |||
| unsigned int timeout_in_sec); | |||
| BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); | |||
| BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size); | |||
| BlockQueueStatus_T Pop(); | |||
| bool Destroy(); | |||
| @@ -20,17 +20,17 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace gpu { | |||
| CollectiveInitializer& CollectiveInitializer::instance() { | |||
| CollectiveInitializer &CollectiveInitializer::instance() { | |||
| static CollectiveInitializer instance = {}; | |||
| return instance; | |||
| } | |||
| bool CollectiveInitializer::collective_inited() const { return collective_inited_; } | |||
| const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } | |||
| const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } | |||
| void CollectiveInitializer::InitCollective() { | |||
| void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); | |||
| void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); | |||
| if (handle == nullptr) { | |||
| MS_LOG(EXCEPTION) | |||
| << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " | |||
| @@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { | |||
| CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); | |||
| } | |||
| bool GPUDeviceManager::CreateStream(DeviceStream* stream) { | |||
| bool GPUDeviceManager::CreateStream(DeviceStream *stream) { | |||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); | |||
| gpu_streams_.emplace_back(*stream); | |||
| return true; | |||
| } | |||
| const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } | |||
| const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } | |||
| int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } | |||
| @@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } | |||
| bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } | |||
| const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } | |||
| const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } | |||
| const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } | |||
| const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } | |||
| bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } | |||
| bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } | |||
| bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { | |||
| bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { | |||
| return CudaDriver::CopyDeviceMemToHost(dst, src, size); | |||
| } | |||
| bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { | |||
| bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { | |||
| return CudaDriver::CopyHostMemToDevice(dst, src, size); | |||
| } | |||
| } // namespace gpu | |||
| @@ -37,17 +37,17 @@ class GPUDeviceManager { | |||
| uint32_t cur_device_id() const; | |||
| bool is_device_id_init() const; | |||
| bool CreateStream(DeviceStream* stream); | |||
| bool SyncStream(const DeviceStream& stream) const; | |||
| const DeviceStream& default_stream() const; | |||
| bool CreateStream(DeviceStream *stream); | |||
| bool SyncStream(const DeviceStream &stream) const; | |||
| const DeviceStream &default_stream() const; | |||
| const cudnnHandle_t& GetCudnnHandle() const; | |||
| const cublasHandle_t& GetCublasHandle() const; | |||
| const cudnnHandle_t &GetCudnnHandle() const; | |||
| const cublasHandle_t &GetCublasHandle() const; | |||
| bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; | |||
| bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; | |||
| bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; | |||
| bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; | |||
| static GPUDeviceManager& GetInstance() { | |||
| static GPUDeviceManager &GetInstance() { | |||
| static GPUDeviceManager instance; | |||
| return instance; | |||
| } | |||
| @@ -55,8 +55,8 @@ class GPUDeviceManager { | |||
| private: | |||
| GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} | |||
| ~GPUDeviceManager() = default; | |||
| GPUDeviceManager(const GPUDeviceManager&) = delete; | |||
| GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; | |||
| GPUDeviceManager(const GPUDeviceManager &) = delete; | |||
| GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; | |||
| // default CUDA stream used for all the kernels. | |||
| DeviceStream default_stream_{nullptr}; | |||
| @@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { | |||
| return true; | |||
| } | |||
| bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { | |||
| bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { | |||
| auto alloc_size = AllocDeviceMem(size, addr); | |||
| buffer_q_addr_ = *addr; | |||
| // Buffer queue needs to ensure that the alloc_size and size is equal. | |||
| return (alloc_size == size) ? true : false; | |||
| } | |||
| size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||
| size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | |||
| if (size == 0) { | |||
| MS_LOG(EXCEPTION) << "The memory alloc size is 0."; | |||
| } | |||
| @@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||
| return alloc_size; | |||
| } | |||
| bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } | |||
| bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } | |||
| size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } | |||
| @@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { | |||
| ~GPUMemoryAllocator() override = default; | |||
| bool Init(); | |||
| bool Finalize(); | |||
| bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); | |||
| bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); | |||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; | |||
| bool FreeDeviceMem(const DeviceMemPtr& addr) override; | |||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | |||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | |||
| size_t free_mem_size() override; | |||
| size_t total_mem_size() override; | |||
| static GPUMemoryAllocator& GetInstance() { | |||
| static GPUMemoryAllocator &GetInstance() { | |||
| static GPUMemoryAllocator instance; | |||
| return instance; | |||
| } | |||
| private: | |||
| GPUMemoryAllocator() = default; | |||
| GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; | |||
| GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; | |||
| GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; | |||
| GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; | |||
| // Used to track address of data buffer queue. | |||
| DeviceMemPtr buffer_q_addr_{nullptr}; | |||
| @@ -33,8 +33,8 @@ namespace gpu { | |||
| using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | |||
| using mindspore::kernel::KernelBuildInfo; | |||
| namespace { | |||
| bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info, | |||
| const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { | |||
| bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info, | |||
| const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) { | |||
| MS_EXCEPTION_IF_NULL(selected_kernel_info); | |||
| MS_EXCEPTION_IF_NULL(alternative_kernel_info); | |||
| size_t selected_input_num = selected_kernel_info->GetInputNum(); | |||
| @@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_ | |||
| return true; | |||
| } | |||
| std::string SupportedTypeList(const CNodePtr& kernel_node) { | |||
| std::string SupportedTypeList(const CNodePtr &kernel_node) { | |||
| std::string supported_type_lists = | |||
| kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); | |||
| if (!supported_type_lists.empty()) { | |||
| @@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { | |||
| return supported_type_lists; | |||
| } | |||
| bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { | |||
| bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(selected_kernel_info); | |||
| std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list; | |||
| @@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu | |||
| } | |||
| bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | |||
| [&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) { | |||
| [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) { | |||
| return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); | |||
| }); | |||
| if (!match) { | |||
| @@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu | |||
| return true; | |||
| } | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) { | |||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| auto input_kernel_node = kernel_node->input(input_index + 1); | |||
| @@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co | |||
| } | |||
| } // namespace | |||
| void SetKernelInfo(const CNodePtr& kernel_node) { | |||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::string> inputs_format; | |||
| std::vector<TypeId> inputs_type; | |||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace gpu { | |||
| void SetKernelInfo(const CNodePtr& apply_kernel_ptr); | |||
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | |||
| class KernelAttr { | |||
| public: | |||
| @@ -35,24 +35,24 @@ class KernelAttr { | |||
| KernelAttr() : all_same_(false) {} | |||
| ~KernelAttr() = default; | |||
| KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { | |||
| KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||
| input_type_.emplace_back(ms_type, format); | |||
| return *this; | |||
| } | |||
| KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { | |||
| KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||
| output_type_.emplace_back(ms_type, format); | |||
| return *this; | |||
| } | |||
| KernelAttr& AddAllSameAttr(const bool& all_same) { | |||
| KernelAttr &AddAllSameAttr(const bool &all_same) { | |||
| all_same_ = all_same; | |||
| return *this; | |||
| } | |||
| const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } | |||
| const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||
| const bool& GetAllSame() const { return all_same_; } | |||
| const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } | |||
| const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||
| const bool &GetAllSame() const { return all_same_; } | |||
| size_t GetInputSize() const { return input_type_.size(); } | |||
| size_t GetOutputSize() const { return output_type_.size(); } | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore { | |||
| struct TypeIdManager* TypeIdManager::Get() { | |||
| struct TypeIdManager *TypeIdManager::Get() { | |||
| static TypeIdManager manager; | |||
| return &manager; | |||
| } | |||
| @@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra | |||
| BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } | |||
| std::string AnfNode::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info()); | |||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | |||
| } | |||
| CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph) | |||
| CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | |||
| : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} | |||
| // Check if CNode is an apply with the specific Primitive. | |||
| bool CNode::IsApply(const PrimitivePtr& value) const { | |||
| bool CNode::IsApply(const PrimitivePtr &value) const { | |||
| if (value == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { | |||
| return false; | |||
| } | |||
| void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } | |||
| void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } | |||
| std::string CNode::DebugString(int recursive_level) const { | |||
| std::ostringstream buffer; | |||
| @@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { | |||
| buffer << ToString() << "{"; | |||
| bool is_first_node = true; | |||
| int idx = 0; | |||
| for (auto& node : inputs_) { | |||
| for (auto &node : inputs_) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (is_first_node) { | |||
| is_first_node = false; | |||
| @@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { | |||
| return buffer.str(); | |||
| } | |||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { | |||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||
| if (operator_info_ != nullptr) { | |||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | |||
| << ", using the new one: " << operator_info->name(); | |||
| @@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { | |||
| return fullname_with_scope_; | |||
| } | |||
| void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); } | |||
| void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); } | |||
| void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); } | |||
| void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); } | |||
| void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); } | |||
| void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); } | |||
| bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode != nullptr) { | |||
| @@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { | |||
| return false; | |||
| } | |||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { | |||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { | |||
| return ""; | |||
| } | |||
| bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { | |||
| bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { | |||
| if (IsValueNode<Primitive>(node)) { | |||
| PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| @@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { | |||
| } | |||
| namespace id_generator { | |||
| static std::unordered_map<std::string, int> node_ids; | |||
| std::string get_id(const AnfNodePtr& node) { | |||
| std::string get_id(const AnfNodePtr &node) { | |||
| auto type_name = node->type_name(); | |||
| if (node_ids.find(type_name) == node_ids.end()) { | |||
| node_ids[type_name] = 0; | |||
| @@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {}; | |||
| class Base : public std::enable_shared_from_this<Base> { | |||
| public: | |||
| constexpr Base() = default; | |||
| Base(const Base& other) : std::enable_shared_from_this<Base>(other) {} | |||
| virtual bool operator==(const Base& rhs) { | |||
| Base(const Base &other) : std::enable_shared_from_this<Base>(other) {} | |||
| virtual bool operator==(const Base &rhs) { | |||
| if (this == &rhs) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| virtual Base& operator=(const Base&) { return *this; } | |||
| virtual Base &operator=(const Base &) { return *this; } | |||
| virtual ~Base() = default; | |||
| virtual std::size_t hash() const { return tid(); } | |||
| virtual std::string ToString() const { return type_name(); } | |||
| @@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> { | |||
| virtual const bool IsFromTypeId(uint32_t tid) const; | |||
| virtual std::string type_name() const { return "Base"; } | |||
| static uint32_t GetTypeId(const char* const type_key); | |||
| static uint32_t GetTypeId(const char *const type_key); | |||
| virtual uint32_t tid() const { | |||
| static const uint32_t tid = GetTypeId(typeid(Base).name()); | |||
| return tid; | |||
| } | |||
| template <typename T, | |||
| typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr> | |||
| typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr> | |||
| inline bool isa() const { | |||
| static const uint32_t tid = GetTypeId(typeid(T).name()); | |||
| return this->IsFromTypeId(tid); | |||
| @@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>; | |||
| using BaseWeakPtr = std::weak_ptr<Base>; | |||
| template <typename T, typename U> | |||
| inline T* cast(U* source) { | |||
| inline T *cast(U *source) { | |||
| if (source != nullptr && source->template isa<T>()) { | |||
| return static_cast<T*>(source); | |||
| return static_cast<T *>(source); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| @@ -100,7 +100,7 @@ inline T* cast(U* source) { | |||
| template < | |||
| typename T, typename U, | |||
| typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr> | |||
| typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr> | |||
| inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) { | |||
| if (r != nullptr && r->template isa<T>()) { | |||
| return std::static_pointer_cast<T>(r); | |||
| @@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { | |||
| std::mutex mutex; | |||
| std::atomic<uint32_t> type_counter{0}; | |||
| std::unordered_map<std::string, uint32_t> map; | |||
| static TypeIdManager* Get(); | |||
| static TypeIdManager *Get(); | |||
| TypeIdManager() : mutex(), type_counter(0), map() {} | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -48,11 +48,11 @@ std::string Keyword::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| bool Keyword::operator==(const Type& other) const { | |||
| bool Keyword::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| const auto& other_keyword = static_cast<const Keyword&>(other); | |||
| const auto &other_keyword = static_cast<const Keyword &>(other); | |||
| return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); | |||
| } | |||
| @@ -87,11 +87,11 @@ std::string Slice::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| bool Slice::operator==(const Type& other) const { | |||
| bool Slice::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_slice = static_cast<const Slice&>(other); | |||
| auto other_slice = static_cast<const Slice &>(other); | |||
| return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); | |||
| } | |||
| @@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { | |||
| } | |||
| } | |||
| bool TensorType::operator==(const Type& other) const { | |||
| bool TensorType::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_elem_type = static_cast<const TensorType&>(other).element_type_; | |||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||
| // When element_type_ = nullptr, which means any type of Array. | |||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||
| return true; | |||
| @@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { | |||
| retval_ = nullptr; | |||
| } | |||
| Function::Function(const std::vector<TypePtr>& args, const TypePtr retval) | |||
| Function::Function(const std::vector<TypePtr> &args, const TypePtr retval) | |||
| : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} | |||
| TypePtr Function::DeepCopy() const { | |||
| @@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { | |||
| TypePtrList args; | |||
| TypePtr retval = nullptr; | |||
| (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), | |||
| [](const TypePtr& arg) { return arg->DeepCopy(); }); | |||
| [](const TypePtr &arg) { return arg->DeepCopy(); }); | |||
| if (retval_ != nullptr) { | |||
| retval = retval_->DeepCopy(); | |||
| } | |||
| @@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { | |||
| } | |||
| } | |||
| bool Function::operator==(const Type& other) const { | |||
| bool Function::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| const auto& other_function = static_cast<const Function&>(other); | |||
| const auto &other_function = static_cast<const Function &>(other); | |||
| if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { | |||
| if (*retval_ != *other_function.retval_) { | |||
| return false; | |||
| @@ -188,7 +188,7 @@ std::string Function::ToString() const { | |||
| } else { | |||
| buffer << "Func[("; | |||
| bool begin = true; | |||
| for (auto& attr : args_) { | |||
| for (auto &attr : args_) { | |||
| if (!begin) { | |||
| buffer << ", "; | |||
| } else { | |||
| @@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { | |||
| return buffer.str(); | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) { | |||
| std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) { | |||
| MS_EXCEPTION_IF_NULL(problem); | |||
| os << problem->ToString(); | |||
| return os; | |||
| } | |||
| std::size_t TypeHasher::operator()(TypePtr const& type) const { | |||
| std::size_t TypeHasher::operator()(TypePtr const &type) const { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t hash = std::hash<size_t>()(type->type_id()); | |||
| return hash; | |||
| } | |||
| std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { | |||
| std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { | |||
| std::size_t hash_sum = 0; | |||
| for (auto& type : type_list) { | |||
| for (auto &type : type_list) { | |||
| auto type_id = static_cast<std::size_t>(type->type_id()); | |||
| hash_sum = hash_combine(hash_sum, type_id); | |||
| } | |||
| return hash_sum; | |||
| } | |||
| bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { | |||
| bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->type_id() == t2->type_id(); | |||
| } | |||
| bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { | |||
| bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { | |||
| if (lhs.size() != rhs.size()) { | |||
| return false; | |||
| } | |||
| @@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { | |||
| namespace { | |||
| template <typename T> | |||
| TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { | |||
| TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == num_type_name) { | |||
| type = std::make_shared<T>(); | |||
| @@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ | |||
| } | |||
| auto bits = std::stoi(type_name.substr(num_type_name.size())); | |||
| type = std::make_shared<T>(bits); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { | |||
| std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) { | |||
| std::vector<TypePtr> types; | |||
| if (type_names.length() == 0) { | |||
| return types; | |||
| @@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { | |||
| return types; | |||
| } | |||
| TypePtr TensorStrToType(const std::string& type_name) { | |||
| TypePtr TensorStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tensor") { | |||
| type = std::make_shared<TensorType>(); | |||
| @@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<TensorType>(element_type); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| @@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { | |||
| return type; | |||
| } | |||
| TypePtr ListStrToType(const std::string& type_name) { | |||
| TypePtr ListStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "List") { | |||
| type = std::make_shared<List>(); | |||
| @@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<List>(element_types); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| @@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { | |||
| return type; | |||
| } | |||
| TypePtr TupleStrToType(const std::string& type_name) { | |||
| TypePtr TupleStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tuple") { | |||
| type = std::make_shared<Tuple>(); | |||
| @@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Tuple>(element_types); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr FunctionStrToType(const std::string& type_name) { | |||
| TypePtr FunctionStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Function") { | |||
| @@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { | |||
| std::vector<TypePtr> args_type = StringToVectorOfType(str_args); | |||
| TypePtr retval = StringToType(str_retval); | |||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); | |||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (retval == nullptr || wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Function>(args_type, retval); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| @@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) { | |||
| } | |||
| } // namespace | |||
| TypePtr StringToType(const std::string& type_name) { | |||
| TypePtr StringToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name.compare("None") == 0) { | |||
| type = std::make_shared<TypeNone>(); | |||
| @@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) { | |||
| return type; | |||
| } | |||
| bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { | |||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | |||
| if (x == nullptr || base_type == nullptr) { | |||
| MS_LOG(ERROR) << "Type is nullptr."; | |||
| return false; | |||
| @@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { | |||
| } | |||
| } | |||
| bool IsSubType(TypePtr const& t1, TypePtr const& t2) { | |||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2) { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| if (t1->type_id() == kTypeUnknown) { | |||
| return false; | |||
| @@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| typing, ([](py::module* const m) { | |||
| typing, ([](py::module *const m) { | |||
| auto m_sub = m->def_submodule("typing", "submodule for dtype"); | |||
| py::enum_<TypeId>(m_sub, "TypeId"); | |||
| (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); | |||
| (void)m_sub.def("load_type", &TypeIdToType, "load type"); | |||
| (void)m_sub.def( | |||
| "dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); | |||
| "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); | |||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | |||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | |||
| .def("__eq__", | |||
| [](const TypePtr& t1, const TypePtr& t2) { | |||
| [](const TypePtr &t1, const TypePtr &t2) { | |||
| if (t1 != nullptr && t2 != nullptr) { | |||
| return *t1 == *t2; | |||
| } | |||
| @@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE( | |||
| .def("__hash__", &Type::hash) | |||
| .def("__str__", &Type::ToString) | |||
| .def("__repr__", &Type::ReprString) | |||
| .def("__deepcopy__", [](const TypePtr& t, py::dict) { | |||
| .def("__deepcopy__", [](const TypePtr &t, py::dict) { | |||
| if (t == nullptr) { | |||
| return static_cast<TypePtr>(nullptr); | |||
| } | |||
| @@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | |||
| .def(py::init()) | |||
| .def(py::pickle( | |||
| [](const Bool&) { // __getstate__ | |||
| [](const Bool &) { // __getstate__ | |||
| return py::make_tuple(); | |||
| }, | |||
| [](const py::tuple&) { // __setstate__ | |||
| [](const py::tuple &) { // __setstate__ | |||
| return std::make_shared<Bool>(); | |||
| })); | |||
| (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Int& t) { // __getstate__ | |||
| [](const Int &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple& t) { // __setstate__ | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| @@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE( | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const UInt& t) { // __getstate__ | |||
| [](const UInt &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple& t) { // __setstate__ | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| @@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE( | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Float& t) { // __getstate__ | |||
| [](const Float &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple& t) { // __setstate__ | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| @@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE( | |||
| .def(py::init<TypePtr>(), py::arg("element")) | |||
| .def("element_type", &TensorType::element) | |||
| .def(py::pickle( | |||
| [](const TensorType& t) { // __getstate__ | |||
| [](const TensorType &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); | |||
| }, | |||
| [](const py::tuple& t) { // __setstate__ | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| @@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>; | |||
| class Keyword : public Object { | |||
| public: | |||
| Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} | |||
| Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} | |||
| Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} | |||
| ~Keyword() override = default; | |||
| MS_DECLARE_PARENT(Keyword, Object) | |||
| @@ -70,7 +70,7 @@ class Keyword : public Object { | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| std::string GetKey() const { return key_; } | |||
| TypePtr GetValue() const { return value_; } | |||
| @@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>; | |||
| class Slice : public Object { | |||
| public: | |||
| Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} | |||
| Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) | |||
| Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) | |||
| : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} | |||
| ~Slice() override = default; | |||
| @@ -95,7 +95,7 @@ class Slice : public Object { | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| TypePtr get_start() const { return start_; } | |||
| TypePtr get_stop() const { return stop_; } | |||
| @@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>; | |||
| class TensorType : public Object { | |||
| public: | |||
| TensorType() : Object(kObjectTypeTensorType) {} | |||
| explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} | |||
| explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} | |||
| ~TensorType() override = default; | |||
| MS_DECLARE_PARENT(TensorType, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | |||
| const TypePtr element() const { return element_type_; } | |||
| void set_element(const TypePtr& element_type) { element_type_ = element_type; } | |||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override { return "tensor"; } | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| private: | |||
| TypePtr element_type_; | |||
| @@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>; | |||
| class Function : public Object { | |||
| public: | |||
| Function(); | |||
| Function(const std::vector<TypePtr>& args, const TypePtr retval); | |||
| Function(const std::vector<TypePtr> &args, const TypePtr retval); | |||
| ~Function() override = default; | |||
| MS_DECLARE_PARENT(Function, Object) | |||
| @@ -141,11 +141,11 @@ class Function : public Object { | |||
| // Add temporarily for return abstraction to avoid type checking. | |||
| bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } | |||
| const std::vector<TypePtr>& args() const { return args_; } | |||
| const TypePtr& retval() const { return retval_; } | |||
| const std::vector<TypePtr> &args() const { return args_; } | |||
| const TypePtr &retval() const { return retval_; } | |||
| TypePtr DeepCopy() const override; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override { return "function"; } | |||
| @@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>; | |||
| class JTagged : public Object { | |||
| public: | |||
| JTagged() : Object(kObjectTypeJTagged) {} | |||
| explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} | |||
| explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} | |||
| ~JTagged() override = default; | |||
| MS_DECLARE_PARENT(JTagged, Object) | |||
| @@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>; | |||
| class Problem : public Type { | |||
| public: | |||
| Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} | |||
| explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} | |||
| explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} | |||
| ~Problem() override = default; | |||
| MS_DECLARE_PARENT(Problem, Type) | |||
| @@ -222,7 +222,7 @@ class Problem : public Type { | |||
| std::string ToString() const override { return kind_.name(); } | |||
| std::string DumpText() const override { return "ProblemType"; } | |||
| friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem); | |||
| friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem); | |||
| private: | |||
| Named kind_; | |||
| @@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>; | |||
| // helper template | |||
| template <class T> | |||
| TypePtr Clone(const T& t) { | |||
| TypePtr Clone(const T &t) { | |||
| return t.Clone(); | |||
| } | |||
| TypePtr StringToType(const std::string& type_name); | |||
| TypePtr StringToType(const std::string &type_name); | |||
| // Judge whether x is predicate or is a subclass of predicate. | |||
| bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); | |||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); | |||
| // Whether t1 is identity or a subclass of t2. | |||
| bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); | |||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); | |||
| struct TypeHasher { | |||
| std::size_t operator()(TypePtr const& type) const; | |||
| std::size_t operator()(TypePtr const &type) const; | |||
| }; | |||
| struct TypeListHasher { | |||
| std::size_t operator()(const TypePtrList& type_list) const; | |||
| std::size_t operator()(const TypePtrList &type_list) const; | |||
| }; | |||
| struct TypeEqual { | |||
| bool operator()(TypePtr const& t1, TypePtr const& t2) const; | |||
| bool operator()(TypePtr const &t1, TypePtr const &t2) const; | |||
| }; | |||
| struct TypeListEqual { | |||
| bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; | |||
| bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; | |||
| }; | |||
| extern const TypePtr kTypeExternal; | |||
| @@ -24,7 +24,7 @@ | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) { | |||
| static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) { | |||
| std::ostringstream oss; | |||
| bool begin = true; | |||
| int cnt = 0; | |||
| @@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { | |||
| } else { | |||
| TypePtrList elements; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | |||
| [](const TypePtr& ele) { return ele->DeepCopy(); }); | |||
| [](const TypePtr &ele) { return ele->DeepCopy(); }); | |||
| auto copy = std::make_shared<List>(elements); | |||
| return copy; | |||
| } | |||
| @@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { | |||
| return elements_[dim]; | |||
| } | |||
| bool List::operator==(const Type& other) const { | |||
| bool List::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| const List& other_list = static_cast<const List&>(other); | |||
| const List &other_list = static_cast<const List &>(other); | |||
| if (elements_.size() != other_list.elements_.size()) { | |||
| return false; | |||
| } | |||
| @@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { | |||
| return true; | |||
| } | |||
| Class::Class(const Named& tag, const ClassAttrVector& attributes, | |||
| const std::unordered_map<std::string, ValuePtr>& methods) | |||
| Class::Class(const Named &tag, const ClassAttrVector &attributes, | |||
| const std::unordered_map<std::string, ValuePtr> &methods) | |||
| : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} | |||
| std::string List::ToString() const { | |||
| @@ -122,7 +122,7 @@ std::string List::DumpText() const { | |||
| return buffer.str(); | |||
| } | |||
| bool Class::operator==(const Type& other) const { | |||
| bool Class::operator==(const Type &other) const { | |||
| // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. | |||
| return &other == this; | |||
| } | |||
| @@ -143,7 +143,7 @@ std::string Class::ToString() const { | |||
| } else { | |||
| bool begin = true; | |||
| buffer << "cls." << tag_ << "["; | |||
| for (auto& attr : attributes_) { | |||
| for (auto &attr : attributes_) { | |||
| if (!begin) { | |||
| buffer << ", "; | |||
| } else { | |||
| @@ -163,7 +163,7 @@ std::string Class::DumpText() const { | |||
| } else { | |||
| bool begin = true; | |||
| buffer << "Cls." << tag_ << "["; | |||
| for (auto& attr : attributes_) { | |||
| for (auto &attr : attributes_) { | |||
| if (!begin) { | |||
| buffer << ", "; | |||
| } else { | |||
| @@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { | |||
| } else { | |||
| TypePtrList elements; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | |||
| [](const TypePtr& ele) { return ele->DeepCopy(); }); | |||
| [](const TypePtr &ele) { return ele->DeepCopy(); }); | |||
| auto copy = std::make_shared<Tuple>(elements); | |||
| return copy; | |||
| } | |||
| } | |||
| bool Tuple::operator==(const Type& other) const { | |||
| bool Tuple::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_tuple = static_cast<const Tuple&>(other); | |||
| auto other_tuple = static_cast<const Tuple &>(other); | |||
| if (elements_.size() != other_tuple.elements_.size()) { | |||
| return false; | |||
| } | |||
| @@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { | |||
| std::vector<std::pair<std::string, TypePtr>> kv; | |||
| (void)std::transform( | |||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | |||
| [](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); | |||
| [](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); }); | |||
| return std::make_shared<Dictionary>(kv); | |||
| } | |||
| } | |||
| @@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { | |||
| std::ostringstream buffer; | |||
| std::vector<std::string> keys; | |||
| std::vector<TypePtr> values; | |||
| for (const auto& kv : key_values_) { | |||
| for (const auto &kv : key_values_) { | |||
| keys.push_back(kv.first); | |||
| values.push_back(kv.second); | |||
| } | |||
| @@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { | |||
| std::string Dictionary::DumpText() const { return ToString(); } | |||
| bool Dictionary::operator==(const mindspore::Type& other) const { | |||
| bool Dictionary::operator==(const mindspore::Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| const auto& other_dict = static_cast<const Dictionary&>(other); | |||
| const auto &other_dict = static_cast<const Dictionary &>(other); | |||
| if (key_values_.size() != other_dict.key_values_.size()) { | |||
| return false; | |||
| } | |||
| @@ -40,10 +40,10 @@ namespace mindspore { | |||
| class List : public Object { | |||
| public: | |||
| List() : Object(kObjectTypeList) {} | |||
| List(const std::initializer_list<TypePtr>& objs) | |||
| List(const std::initializer_list<TypePtr> &objs) | |||
| : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} | |||
| // Shadow copy; | |||
| explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} | |||
| explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} | |||
| ~List() override {} | |||
| MS_DECLARE_PARENT(List, Object) | |||
| @@ -51,7 +51,7 @@ class List : public Object { | |||
| TypeId generic_type_id() const override { return kObjectTypeList; } | |||
| TypePtr DeepCopy() const override; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| std::size_t size() const { return elements_.size(); } | |||
| TypePtrList elements() const { return elements_; } | |||
| std::string ToString() const override; | |||
| @@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>; | |||
| class Class : public Object { | |||
| public: | |||
| Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} | |||
| Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods); | |||
| Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods); | |||
| ~Class() override {} | |||
| MS_DECLARE_PARENT(Class, Object) | |||
| TypeId generic_type_id() const override { return kObjectTypeClass; } | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; } | |||
| void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; } | |||
| Named tag() { return tag_; } | |||
| std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; } | |||
| std::unordered_map<std::string, ValuePtr> methods() { return methods_; } | |||
| ClassAttrVector& GetAttributes() { return attributes_; } | |||
| ClassAttrVector &GetAttributes() { return attributes_; } | |||
| ClassAttrVector attributes_; | |||
| @@ -99,11 +99,11 @@ class Tuple : public Object { | |||
| public: | |||
| Tuple() : Object(kObjectTypeTuple) {} | |||
| // usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)}; | |||
| Tuple(const std::initializer_list<TypePtr>& objs) | |||
| Tuple(const std::initializer_list<TypePtr> &objs) | |||
| : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | |||
| // Shadow copy | |||
| explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | |||
| explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | |||
| ~Tuple() override {} | |||
| MS_DECLARE_PARENT(Tuple, Object) | |||
| @@ -115,7 +115,7 @@ class Tuple : public Object { | |||
| std::string ToReprString() const override { return "tuple_"; } | |||
| std::string DumpText() const override; | |||
| const TypePtr operator[](size_t dim) const; | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| TypePtrList elements() const { return elements_; } | |||
| std::size_t size() const { return elements_.size(); } | |||
| @@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>; | |||
| class Dictionary : public Object { | |||
| public: | |||
| Dictionary() : Object(kObjectTypeDictionary) {} | |||
| explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values) | |||
| explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values) | |||
| : Object(kObjectTypeDictionary, false), key_values_(key_values) {} | |||
| ~Dictionary() override {} | |||
| @@ -136,7 +136,7 @@ class Dictionary : public Object { | |||
| TypeId generic_type_id() const override { return kObjectTypeDictionary; } | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| @@ -24,11 +24,11 @@ | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| bool Number::operator==(const Type& other) const { | |||
| bool Number::operator==(const Type &other) const { | |||
| if (!IsSameObjectType(*this, other)) { | |||
| return false; | |||
| } | |||
| auto other_number = static_cast<const Number&>(other); | |||
| auto other_number = static_cast<const Number &>(other); | |||
| return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); | |||
| } | |||
| @@ -49,12 +49,12 @@ class Number : public Object { | |||
| TypeId type_id() const override { return number_type_; } | |||
| TypeId generic_type_id() const override { return kObjectTypeNumber; } | |||
| bool operator==(const Type& other) const override; | |||
| bool operator==(const Type &other) const override; | |||
| TypePtr DeepCopy() const override { return std::make_shared<Number>(); } | |||
| std::string ToString() const override { return "Number"; } | |||
| std::string ToReprString() const override { return "number"; } | |||
| std::string DumpText() const override { return "Number"; } | |||
| std::string GetTypeName(const std::string& type_name) const { | |||
| std::string GetTypeName(const std::string &type_name) const { | |||
| std::ostringstream oss; | |||
| oss << type_name; | |||
| if (nbits() != 0) { | |||
| @@ -51,7 +51,7 @@ class RefKeyType : public Object { | |||
| class RefType : public Object { | |||
| public: | |||
| RefType() : Object(kObjectTypeRef) {} | |||
| RefType(const TypePtr& subtype, const TypePtr& subtype_origin) | |||
| RefType(const TypePtr &subtype, const TypePtr &subtype_origin) | |||
| : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} | |||
| ~RefType() override {} | |||
| MS_DECLARE_PARENT(RefType, Object) | |||
| @@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { | |||
| } | |||
| } | |||
| const char* MetaIdLabel(const TypeId& v) { | |||
| const char *MetaIdLabel(const TypeId &v) { | |||
| switch (v) { | |||
| case kTypeUnknown: | |||
| return "kTypeUnknown"; | |||
| @@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { | |||
| } | |||
| } | |||
| const char* ObjectIdLabel(const TypeId& v) { | |||
| const char *ObjectIdLabel(const TypeId &v) { | |||
| switch (v) { | |||
| case kObjectTypeNumber: | |||
| return "kObjectTypeNumber"; | |||
| @@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) { | |||
| } | |||
| } | |||
| const char* NumberIdLabel(const TypeId& v) { | |||
| const char *NumberIdLabel(const TypeId &v) { | |||
| switch (v) { | |||
| case kNumberTypeBool: | |||
| return "kNumberTypeBool"; | |||
| @@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { | |||
| } | |||
| } | |||
| const char* TypeIdLabel(const TypeId& v) { | |||
| const char *TypeIdLabel(const TypeId &v) { | |||
| if (v < kMetaTypeEnd) { | |||
| return MetaIdLabel(v); | |||
| } else { | |||
| @@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { | |||
| } | |||
| } | |||
| bool IsSameObjectType(const Type& lhs, const Type& rhs) { | |||
| bool IsSameObjectType(const Type &lhs, const Type &rhs) { | |||
| if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { | |||
| return false; | |||
| } | |||
| return lhs.object_type() == rhs.object_type(); | |||
| } | |||
| size_t GetTypeByte(const TypePtr& type_ptr) { | |||
| size_t GetTypeByte(const TypePtr &type_ptr) { | |||
| if (type_ptr && type_ptr->isa<Number>()) { | |||
| auto number = dyn_cast<Number>(type_ptr); | |||
| if (!number) { | |||
| @@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { | |||
| } | |||
| } | |||
| bool Type::operator==(const Value& other) const { | |||
| bool Type::operator==(const Value &other) const { | |||
| if (other.isa<Type>()) { | |||
| auto other_type = static_cast<const Type*>(&other); | |||
| auto other_type = static_cast<const Type *>(&other); | |||
| return *this == *other_type; | |||
| } else { | |||
| return false; | |||
| @@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { | |||
| return ptr; | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const Type& type) { | |||
| std::ostream &operator<<(std::ostream &os, const Type &type) { | |||
| os << type.ToString(); | |||
| return os; | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const TypePtr type) { | |||
| std::ostream &operator<<(std::ostream &os, const TypePtr type) { | |||
| os << type->ToString(); | |||
| return os; | |||
| } | |||
| @@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { | |||
| return false; | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const Object& obj) { | |||
| std::ostream &operator<<(std::ostream &os, const Object &obj) { | |||
| os << obj.ToString(); | |||
| return os; | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) { | |||
| std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) { | |||
| os << obj->ToString(); | |||
| return os; | |||
| } | |||
| std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { | |||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types) { | |||
| os << "["; | |||
| for (size_t i = 0; i < types.size(); ++i) { | |||
| if (i > 0) { | |||
| @@ -95,10 +95,10 @@ enum TypeId : int { | |||
| TypeId IntBitsToTypeId(const int nbits); | |||
| TypeId UIntBitsToTypeId(const int nbits); | |||
| TypeId FloatBitsToTypeId(const int nbits); | |||
| const char* TypeIdLabel(const TypeId& v); | |||
| const char *TypeIdLabel(const TypeId &v); | |||
| TypeId NormalizeTypeId(const TypeId type_id); | |||
| bool IsSameObjectType(const Type& lhs, const Type& rhs); | |||
| size_t GetTypeByte(const TypePtr& type_ptr); | |||
| bool IsSameObjectType(const Type &lhs, const Type &rhs); | |||
| size_t GetTypeByte(const TypePtr &type_ptr); | |||
| // Base class for all types | |||
| // forward declaration. | |||
| @@ -110,14 +110,14 @@ class Type : public Value { | |||
| ~Type() override = default; | |||
| MS_DECLARE_PARENT(Type, Value) | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Value &other) const override; | |||
| TypeId meta_type() const { return meta_type_; } | |||
| virtual TypeId type_id() const { return meta_type_; } | |||
| virtual TypeId generic_type_id() const { return kMetaTypeType; } | |||
| virtual bool operator!=(const Type& other) const { return !(*this == other); } | |||
| virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } | |||
| virtual bool operator!=(const Type &other) const { return !(*this == other); } | |||
| virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } | |||
| virtual bool equal(const TypePtr other) const { return *this == *other; } | |||
| virtual TypeId object_type() const { return kTypeUnknown; } | |||
| @@ -134,8 +134,8 @@ class Type : public Value { | |||
| bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } | |||
| bool IsGeneric() const { return is_generic_; } | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| friend std::ostream& operator<<(std::ostream& os, const Type& type); | |||
| friend std::ostream& operator<<(std::ostream& os, const TypePtr type); | |||
| friend std::ostream &operator<<(std::ostream &os, const Type &type); | |||
| friend std::ostream &operator<<(std::ostream &os, const TypePtr type); | |||
| const bool parse_info_ = true; | |||
| @@ -163,14 +163,14 @@ class Object : public Type { | |||
| bool equal(const TypePtr other) const override; | |||
| std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } | |||
| friend std::ostream& operator<<(std::ostream& os, const Object& obj); | |||
| friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj); | |||
| friend std::ostream &operator<<(std::ostream &os, const Object &obj); | |||
| friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj); | |||
| private: | |||
| const TypeId object_type_; | |||
| }; | |||
| std::ostream& operator<<(std::ostream& os, const TypePtrList& types); | |||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ | |||
| @@ -61,7 +61,7 @@ FuncGraph::FuncGraph() | |||
| AbstractFunctionPtr FuncGraph::abstract() { | |||
| AbstractBasePtrList args_spec_list; | |||
| for (auto& p : parameters_) { | |||
| for (auto &p : parameters_) { | |||
| MS_EXCEPTION_IF_NULL(p); | |||
| if (p->abstract() == nullptr) { | |||
| MS_LOG(ERROR) << "Error!!"; | |||
| @@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() { | |||
| return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract()); | |||
| } | |||
| abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { | |||
| abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { | |||
| AnalysisContextPtr temp_context = context; | |||
| if (temp_context == nullptr) { | |||
| temp_context = abstract::AnalysisContext::DummyContext(); | |||
| @@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const { | |||
| } | |||
| } | |||
| void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { | |||
| void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { | |||
| if (force_new_ret || return_ == nullptr) { | |||
| std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value}); | |||
| FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | |||
| @@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() { | |||
| return p; | |||
| } | |||
| void FuncGraph::add_parameter(const ParameterPtr& p) { | |||
| void FuncGraph::add_parameter(const ParameterPtr &p) { | |||
| if (manager_.lock()) { | |||
| std::vector<AnfNodePtr> new_params = parameters_; | |||
| new_params.push_back(p); | |||
| @@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) { | |||
| } | |||
| } | |||
| ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { | |||
| ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||
| FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | |||
| ParameterPtr p = std::make_shared<Parameter>(this_graph); | |||
| p->set_name(name); | |||
| @@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { | |||
| return p; | |||
| } | |||
| bool FuncGraph::has_flag(const std::string& flag) { | |||
| bool FuncGraph::has_flag(const std::string &flag) { | |||
| if (flags_.count(flag)) { | |||
| return flags_[flag]; | |||
| } | |||
| return false; | |||
| } | |||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) { | |||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| order_.push_back(cnode); | |||
| @@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) { | |||
| return cnode; | |||
| } | |||
| CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, const ScopePtr& scope) { | |||
| CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) { | |||
| CNodePtr app = NewCNode(inputs); | |||
| app->set_scope(scope); | |||
| return app; | |||
| @@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, con | |||
| void FuncGraph::DumpCNodeList() { | |||
| MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; | |||
| for (const auto& cnode : order_) { | |||
| for (const auto &cnode : order_) { | |||
| MS_LOG(INFO) << cnode->DebugString(); | |||
| } | |||
| } | |||
| std::string FuncGraph::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<FuncGraph*>(this)->shared_from_base<FuncGraph>()->debug_info()); | |||
| return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info()); | |||
| } | |||
| GraphDebugInfoPtr FuncGraph::debug_info() { | |||
| @@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() { | |||
| return this->debug_info_; | |||
| } | |||
| const AnfNodeSet& FuncGraph::nodes() { | |||
| const AnfNodeSet &FuncGraph::nodes() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& nodes = mng->nodes(); | |||
| auto &nodes = mng->nodes(); | |||
| return nodes[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const AnfNodeCounterMap& FuncGraph::value_nodes() { | |||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& cts = mng->valuenodes(); | |||
| auto &cts = mng->valuenodes(); | |||
| return cts[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const AnfNodeCounterMap& FuncGraph::free_variables_direct() { | |||
| const AnfNodeCounterMap &FuncGraph::free_variables_direct() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& fv_direct = mng->free_variables_direct(); | |||
| auto &fv_direct = mng->free_variables_direct(); | |||
| return fv_direct[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const BaseRefCounterMap& FuncGraph::free_variables_total() { | |||
| const BaseRefCounterMap &FuncGraph::free_variables_total() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& fv_total = mng->free_variables_total(); | |||
| auto &fv_total = mng->free_variables_total(); | |||
| return fv_total[shared_from_base<FuncGraph>()]; | |||
| } | |||
| std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { | |||
| std::vector<AnfNodePtr> nodes; | |||
| const auto& fv_total = this->free_variables_total(); | |||
| for (auto& p : fv_total) { | |||
| const auto &fv_total = this->free_variables_total(); | |||
| for (auto &p : fv_total) { | |||
| auto key = p.first; | |||
| if (utils::isa<AnfNodePtr>(key)) { | |||
| nodes.push_back(utils::cast<AnfNodePtr>(key)); | |||
| @@ -238,8 +238,8 @@ std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { | |||
| std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | |||
| std::vector<FuncGraphPtr> func_graphs; | |||
| const auto& fv_total = this->free_variables_total(); | |||
| for (auto& p : fv_total) { | |||
| const auto &fv_total = this->free_variables_total(); | |||
| for (auto &p : fv_total) { | |||
| auto key = p.first; | |||
| if (utils::isa<FuncGraphPtr>(key)) { | |||
| func_graphs.push_back(utils::cast<FuncGraphPtr>(key)); | |||
| @@ -249,31 +249,31 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | |||
| return func_graphs; | |||
| } | |||
| const FuncGraphCounterMap& FuncGraph::func_graphs_used() { | |||
| const FuncGraphCounterMap &FuncGraph::func_graphs_used() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& used = mng->func_graphs_used(); | |||
| auto &used = mng->func_graphs_used(); | |||
| return used[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const FuncGraphSet& FuncGraph::func_graphs_used_total() { | |||
| const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& used = mng->func_graphs_used_total(shared_from_base<FuncGraph>()); | |||
| auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>()); | |||
| return used; | |||
| } | |||
| const FuncGraphCounterMap& FuncGraph::func_graph_users() { | |||
| const FuncGraphCounterMap &FuncGraph::func_graph_users() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& users = mng->func_graph_users(); | |||
| auto &users = mng->func_graph_users(); | |||
| return users[shared_from_base<FuncGraph>()]; | |||
| } | |||
| const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { | |||
| const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| auto& users = mng->func_graph_user_cnodes(); | |||
| auto &users = mng->func_graph_user_cnodes(); | |||
| return users[shared_from_base<FuncGraph>()]; | |||
| } | |||
| @@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() { | |||
| return mng->parent(shared_from_base<FuncGraph>()); | |||
| } | |||
| const FuncGraphSet& FuncGraph::children() { | |||
| const FuncGraphSet &FuncGraph::children() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| return mng->children(shared_from_base<FuncGraph>()); | |||
| } | |||
| const FuncGraphSet& FuncGraph::scope() { | |||
| const FuncGraphSet &FuncGraph::scope() { | |||
| auto mng = manager_.lock(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| return mng->scopes(shared_from_base<FuncGraph>()); | |||
| @@ -312,9 +312,9 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() { | |||
| return mng->recursive_graphs(shared_from_base<FuncGraph>()); | |||
| } | |||
| void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); } | |||
| void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); } | |||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { | |||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | |||
| auto itr = this->parameter_default_value_.find(name); | |||
| if (itr == parameter_default_value_.end()) { | |||
| return nullptr; | |||
| @@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { | |||
| } | |||
| // set the default values | |||
| void FuncGraph::SetDefaultValues(const std::vector<std::string>& name_list, const std::vector<AnfNodePtr>& value_list) { | |||
| void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) { | |||
| auto all_is_null = std::all_of(value_list.begin(), value_list.end(), | |||
| [](const AnfNodePtr& node) { return IsValueNode<NullObj>(node); }); | |||
| [](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); }); | |||
| if (value_list.empty()) { | |||
| all_is_null = true; | |||
| } | |||
| @@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } | |||
| size_t FuncGraph::GetDefaultValueCount() { | |||
| int null_count = | |||
| std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), | |||
| [](const std::pair<std::string, AnfNodePtr>& pair) { return IsValueNode<NullObj>(pair.second); }); | |||
| [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); }); | |||
| return parameter_default_value_.size() - IntToSize(null_count); | |||
| } | |||
| @@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const { | |||
| return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); | |||
| } | |||
| AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { | |||
| AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { | |||
| for (size_t i = 0; i < parameters_.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(parameters_[i]); | |||
| auto param_cast = parameters_[i]->cast<ParameterPtr>(); | |||
| @@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { | |||
| return nullptr; | |||
| } | |||
| void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, | |||
| std::vector<AnfNodePtr>* specialized_parameter_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, int variable_args_count, | |||
| void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, | |||
| std::vector<AnfNodePtr> *specialized_parameter_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count, | |||
| int pos_args_input_count) { | |||
| // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple | |||
| if (specialized_graph->has_vararg()) { | |||
| @@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, | |||
| } | |||
| } | |||
| void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||
| std::vector<AnfNodePtr>* specialized_parameter_list, | |||
| const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { | |||
| void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, | |||
| std::vector<AnfNodePtr> *specialized_parameter_list, | |||
| const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) { | |||
| std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| for (const auto& kwarg : kwarg_list) { | |||
| for (const auto &kwarg : kwarg_list) { | |||
| MS_EXCEPTION_IF_NULL(kwarg); | |||
| std::string kw_param_name = kwarg->get_key(); | |||
| MS_EXCEPTION_IF_NULL(specialized_graph); | |||
| @@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||
| std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; | |||
| MS_EXCEPTION_IF_NULL(specialized_parameter_list); | |||
| auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), | |||
| [param_name](const AnfNodePtr& node) { | |||
| [param_name](const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto param = node->cast<ParameterPtr>(); | |||
| return param != nullptr && param->name() == param_name; | |||
| @@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||
| GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); | |||
| } | |||
| void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, | |||
| const std::vector<AnfNodePtr>& kwarg_keys_tuple_nodes, | |||
| const std::vector<AnfNodePtr>& kwarg_values_tuple_nodes) { | |||
| void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, | |||
| const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes, | |||
| const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) { | |||
| if (has_kwarg()) { | |||
| MS_EXCEPTION_IF_NULL(specialized_graph); | |||
| TraceManager::DebugTrace( | |||
| @@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, | |||
| } | |||
| } | |||
| bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) { | |||
| bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) { | |||
| // if the function does not have any vararg/kwarg/kwonly/default value/kw args input | |||
| // return the original graph | |||
| if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { | |||
| @@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& | |||
| return true; | |||
| } | |||
| void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, | |||
| const std::vector<AnfNodePtr>& specialized_parameter_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { | |||
| void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, | |||
| const std::vector<AnfNodePtr> &specialized_parameter_list, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) { | |||
| MS_EXCEPTION_IF_NULL(specialized_graph); | |||
| for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { | |||
| auto param_node = specialized_graph->parameters()[i]; | |||
| @@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, | |||
| } | |||
| } | |||
| FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { | |||
| std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; | |||
| size_t arguments_count = args_spec_list.size(); | |||
| for (const auto& arg : args_spec_list) { | |||
| for (const auto &arg : args_spec_list) { | |||
| // if it is a keyword argument | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (arg->isa<abstract::AbstractKeywordArg>()) { | |||
| @@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) | |||
| MS_EXCEPTION_IF_NULL(specialized_graph); | |||
| auto params = specialized_graph->parameters(); | |||
| (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), | |||
| std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); | |||
| std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); | |||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); | |||
| auto tr = manager->Transact(); | |||
| for (auto& node_pair : repl_nodes) { | |||
| for (auto &node_pair : repl_nodes) { | |||
| MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" | |||
| << node_pair.second->DebugString(); | |||
| (void)tr.Replace(node_pair.first, node_pair.second); | |||
| @@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) | |||
| return specialized_graph; | |||
| } | |||
| void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } | |||
| void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } | |||
| std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| @@ -651,7 +651,7 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | |||
| std::list<CNodePtr> cnodes; | |||
| auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); | |||
| for (const auto& node : nodes) { | |||
| for (const auto &node : nodes) { | |||
| auto cnode = dyn_cast<CNode>(node); | |||
| if (cnode) { | |||
| cnodes.push_back(cnode); | |||
| @@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() { | |||
| } | |||
| } | |||
| void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { | |||
| void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) { | |||
| order_.remove(n->cast<CNodePtr>()); | |||
| MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; | |||
| @@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() { | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| MS_LOG(DEBUG) << "Check graph " << ToString(); | |||
| for (auto it = order_.begin(); it != order_.end(); (void)it++) { | |||
| for (const auto& input_node : (*it)->inputs()) { | |||
| for (const auto &input_node : (*it)->inputs()) { | |||
| if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) { | |||
| // Need to reorder the wrong order node. | |||
| auto found = std::find(order_.begin(), it, input_node); | |||
| @@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() { | |||
| } | |||
| auto mng = manager_.lock(); | |||
| if (mng != nullptr) { | |||
| const auto& nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||
| const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||
| if (nodes.size() != (order_.size() + parameters_.size())) { | |||
| DumpCNodeList(); | |||
| MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | |||
| @@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() { | |||
| const char kPrimHasEffect[] = "_side_effect_flag"; | |||
| bool FuncGraph::HasEffect(const CNodePtr& cnode) { | |||
| bool FuncGraph::HasEffect(const CNodePtr &cnode) { | |||
| auto prim = GetCNodePrimitive(cnode); | |||
| if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) { | |||
| auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>(); | |||
| @@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) { | |||
| return false; | |||
| } | |||
| std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& segment) { | |||
| std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) { | |||
| std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment); | |||
| for (const auto& node : segment) { | |||
| for (const auto &node : segment) { | |||
| if (roots->size() == 1) { | |||
| return roots; | |||
| } | |||
| @@ -757,9 +757,9 @@ std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& seg | |||
| return roots; | |||
| } | |||
| std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr>& segment) { | |||
| std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) { | |||
| std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment); | |||
| for (const auto& node : segment) { | |||
| for (const auto &node : segment) { | |||
| if (nodes->size() == 1) { | |||
| return nodes; | |||
| } | |||
| @@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | |||
| std::list<AnfNodePtr> depends_order; | |||
| std::vector<CNodePtr> segment; | |||
| for (const auto& cnode : order_) { | |||
| for (const auto &cnode : order_) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| @@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||
| } | |||
| } | |||
| void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) { | |||
| void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) { | |||
| auto old_ret = output(); | |||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret}; | |||
| (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); | |||
| @@ -26,29 +26,29 @@ | |||
| // namespace to support intermediate representation definition | |||
| namespace mindspore { | |||
| Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, | |||
| bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) | |||
| Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, | |||
| bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) | |||
| : clone_all_valuenodes_(clone_all_valuenodes), | |||
| clone_all_child_graphs_(clone_all_child_graphs), | |||
| clone_all_used_graphs_(clone_all_used_graphs), | |||
| relation_(relation), | |||
| target_relation_(target_relation == nullptr ? relation : target_relation) { | |||
| for (auto& func_graph : func_graphs) { | |||
| for (auto &func_graph : func_graphs) { | |||
| AddClone(func_graph); | |||
| } | |||
| scope_ = kDefaultScope; | |||
| type_ = kBasic; | |||
| } | |||
| void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||
| const AnfNodePtrList& params, CloneType type) { | |||
| void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||
| const AnfNodePtrList ¶ms, CloneType type) { | |||
| if (func_graph != nullptr) { | |||
| todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); | |||
| type_ = type; | |||
| } | |||
| } | |||
| void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||
| void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) { | |||
| return; | |||
| @@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||
| } | |||
| } | |||
| void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { | |||
| void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(target); | |||
| TraceManager::DebugTrace(node->debug_info(), relation_); | |||
| @@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, | |||
| TraceManager::EndTrace(); | |||
| } | |||
| void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||
| void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(target); | |||
| TraceManager::DebugTrace(node->debug_info(), relation_); | |||
| @@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||
| TraceManager::EndTrace(); | |||
| } | |||
| void Cloner::CloneValueNode(const AnfNodePtr& node) { | |||
| void Cloner::CloneValueNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| TraceManager::DebugTrace(node->debug_info(), relation_); | |||
| ValueNodePtr new_const = NewValueNode(GetValueNode(node)); | |||
| @@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) { | |||
| TraceManager::EndTrace(); | |||
| } | |||
| void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||
| void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(target); | |||
| TraceManager::DebugTrace(node->debug_info(), relation_); | |||
| @@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) | |||
| TraceManager::EndTrace(); | |||
| } | |||
| void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { | |||
| void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| if (!clone_all_valuenodes_) { | |||
| return; | |||
| } | |||
| auto& value_nodes = manager_->valuenodes()[func_graph]; | |||
| for (auto& value_node : value_nodes) { | |||
| auto &value_nodes = manager_->valuenodes()[func_graph]; | |||
| for (auto &value_node : value_nodes) { | |||
| auto old_node = value_node.first; | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| if (repl_node_.count(old_node) == 0) { | |||
| @@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { | |||
| } | |||
| } | |||
| void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { | |||
| void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| if (!clone_all_child_graphs_) { | |||
| return; | |||
| } | |||
| auto& scopes = manager_->scopes(func_graph); | |||
| for (auto& graph : scopes) { | |||
| auto &scopes = manager_->scopes(func_graph); | |||
| for (auto &graph : scopes) { | |||
| if (graph != func_graph) { | |||
| todo_.push_back({graph, nullptr, {}}); | |||
| } | |||
| } | |||
| } | |||
| void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { | |||
| void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| if (!clone_all_used_graphs_) { | |||
| return; | |||
| } | |||
| auto& used_graphs = manager_->func_graphs_used()[func_graph]; | |||
| for (auto& used_graph : used_graphs) { | |||
| auto &used_graphs = manager_->func_graphs_used()[func_graph]; | |||
| for (auto &used_graph : used_graphs) { | |||
| todo_.push_back({used_graph.first, nullptr, {}}); | |||
| } | |||
| } | |||
| void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||
| void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| for (auto& item : func_graph->parameter_default_value()) { | |||
| for (auto &item : func_graph->parameter_default_value()) { | |||
| auto nodes = DeepLinkedGraphSearch(item.second); | |||
| for (auto& node : nodes) { | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| CloneNode(node, target_func_graph); | |||
| @@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F | |||
| } | |||
| } | |||
| void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||
| void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| @@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func | |||
| } | |||
| target_func_graph->set_return(return_node); | |||
| auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||
| for (auto& value_node : value_nodes) { | |||
| auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||
| for (auto &value_node : value_nodes) { | |||
| CloneValueNode(value_node.first, target_func_graph); | |||
| } | |||
| } | |||
| void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { | |||
| void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto& old_params = func_graph->parameters(); | |||
| auto &old_params = func_graph->parameters(); | |||
| if (old_params.size() != params.size()) { | |||
| MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; | |||
| return; | |||
| @@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode | |||
| } | |||
| } | |||
| void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { | |||
| void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | |||
| @@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons | |||
| TraceManager::EndTrace(); | |||
| } | |||
| void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||
| void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| auto& params = func_graph->parameters(); | |||
| for (auto& param : params) { | |||
| auto ¶ms = func_graph->parameters(); | |||
| for (auto ¶m : params) { | |||
| CloneParameter(param, target_func_graph, true); | |||
| } | |||
| repl_func_graph_[func_graph] = target_func_graph; | |||
| } | |||
| void Cloner::GenParameters(const FuncGraphPtr& func_graph) { | |||
| void Cloner::GenParameters(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto& free_vars = manager_->free_variables_total(); | |||
| auto &free_vars = manager_->free_variables_total(); | |||
| auto iter = free_vars.find(func_graph); | |||
| if (iter == free_vars.end()) { | |||
| return; | |||
| } | |||
| for (auto& fv_map : iter->second) { | |||
| auto& free_var = fv_map.first; | |||
| for (auto &fv_map : iter->second) { | |||
| auto &free_var = fv_map.first; | |||
| if (utils::isa<AnfNodePtr>(free_var)) { | |||
| repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var))); | |||
| } | |||
| } | |||
| } | |||
| void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { | |||
| void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { | |||
| param->set_abstract(node->abstract()); | |||
| if (node->isa<Parameter>()) { | |||
| ParameterPtr old_param = dyn_cast<Parameter>(node); | |||
| @@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { | |||
| } | |||
| } | |||
| ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { | |||
| ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info())); | |||
| ParameterPtr param = std::make_shared<Parameter>(func_graph); | |||
| TraceManager::EndTrace(); | |||
| @@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP | |||
| return param; | |||
| } | |||
| void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, | |||
| AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { | |||
| void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, | |||
| AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { | |||
| AnfNodePtrList parameters; | |||
| std::unordered_set<AnfNodePtr> old_params; | |||
| for (auto& param : func_graph->parameters()) { | |||
| for (auto ¶m : func_graph->parameters()) { | |||
| auto iter = repl_node_.find(param); | |||
| if (iter != repl_node_.end()) { | |||
| (void)old_params.insert(iter->second); | |||
| @@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& | |||
| } | |||
| } | |||
| AnfNodePtr new_param = nullptr; | |||
| for (auto& param : params) { | |||
| for (auto ¶m : params) { | |||
| auto old_param = repl_node_[param]; | |||
| if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) { | |||
| repl_node_[old_param] = old_param; | |||
| @@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& | |||
| func_graph->set_parameters(parameters); | |||
| } | |||
| void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtrList& params) { | |||
| void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtrList ¶ms) { | |||
| AnfNodePtr node = nullptr; | |||
| auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; | |||
| auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; | |||
| auto iter = repl_func_graph.find(func_graph); | |||
| if (iter == repl_func_graph.end()) { | |||
| node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); | |||
| @@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& | |||
| OrderParameters(func_graph, inputs); | |||
| } | |||
| void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { | |||
| void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { | |||
| std::unordered_set<AnfNodePtr> old_params; | |||
| for (auto& param : func_graph->parameters()) { | |||
| for (auto ¶m : func_graph->parameters()) { | |||
| (void)old_params.insert(repl_node_[param]); | |||
| } | |||
| std::unordered_set<AnfNodePtr> new_params; | |||
| @@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis | |||
| (void)new_params.insert(new_param); | |||
| } | |||
| } | |||
| for (auto& param : func_graph->parameters()) { | |||
| for (auto ¶m : func_graph->parameters()) { | |||
| if (new_params.find(param) == new_params.end()) { | |||
| parameters.push_back(param); | |||
| } | |||
| @@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis | |||
| func_graph->set_parameters(parameters); | |||
| } | |||
| void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||
| void Cloner::SetEdges(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| for (auto& node : func_graph->nodes()) { | |||
| for (auto &node : func_graph->nodes()) { | |||
| if (node == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto& inputs = cnode->inputs(); | |||
| auto &inputs = cnode->inputs(); | |||
| for (size_t i = 0; i < inputs.size(); i++) { | |||
| auto& input = inputs[i]; | |||
| auto &input = inputs[i]; | |||
| if (IsValueNode<FuncGraph>(input)) { | |||
| auto graph = GetValueNode<FuncGraphPtr>(input); | |||
| auto& repl_func_graph = repl_map_func_graph_[func_graph]; | |||
| auto &repl_func_graph = repl_map_func_graph_[func_graph]; | |||
| if (repl_func_graph.find(graph) != repl_func_graph.end()) { | |||
| transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); | |||
| } | |||
| } else { | |||
| auto& repl_node = repl_map_node_[func_graph]; | |||
| auto &repl_node = repl_map_node_[func_graph]; | |||
| if (repl_node.find(input) != repl_node.end()) { | |||
| transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); | |||
| } | |||
| @@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||
| } | |||
| } | |||
| void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtrList& params) { | |||
| void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtrList ¶ms) { | |||
| AnfNodePtrList lift_params; | |||
| AnfNodePtrList input_params; | |||
| AddParameters(func_graph_user, params, &lift_params, &input_params); | |||
| @@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph | |||
| if (lift_params.empty()) { | |||
| return; | |||
| } | |||
| for (auto& user : func_graph_user->func_graph_users()) { | |||
| for (auto &user : func_graph_user->func_graph_users()) { | |||
| LiftParameters(user.first, func_graph_user, lift_params); | |||
| } | |||
| } | |||
| void Cloner::Lift() { | |||
| for (auto& func_graph_params : repl_func_graph_params_) { | |||
| auto& func_graph = func_graph_params.first; | |||
| auto& params = func_graph_params.second; | |||
| for (auto& user : func_graph->func_graph_users()) { | |||
| for (auto &func_graph_params : repl_func_graph_params_) { | |||
| auto &func_graph = func_graph_params.first; | |||
| auto ¶ms = func_graph_params.second; | |||
| for (auto &user : func_graph->func_graph_users()) { | |||
| LiftParameters(user.first, func_graph, params); | |||
| } | |||
| } | |||
| @@ -404,18 +404,18 @@ void Cloner::Lift() { | |||
| void Cloner::LiftParameters() { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| transaction_ = manager_->Transact(); | |||
| const FuncGraphSet& func_graphs = manager_->func_graphs(); | |||
| for (auto& func_graph : func_graphs) { | |||
| const FuncGraphSet &func_graphs = manager_->func_graphs(); | |||
| for (auto &func_graph : func_graphs) { | |||
| GenParameters(func_graph); | |||
| } | |||
| Lift(); | |||
| for (auto& func_graph : func_graphs) { | |||
| for (auto &func_graph : func_graphs) { | |||
| SetEdges(func_graph); | |||
| } | |||
| transaction_.Commit(); | |||
| } | |||
| bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { | |||
| bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // Make sure only inline once | |||
| if (status_.count(func_graph) != 0) { | |||
| @@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { | |||
| return true; | |||
| } | |||
| void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||
| void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| const AnfNodeSet& nodes = manager_->nodes()[func_graph]; | |||
| for (auto& node : nodes) { | |||
| const AnfNodeSet &nodes = manager_->nodes()[func_graph]; | |||
| for (auto &node : nodes) { | |||
| CloneNode(node, target_func_graph); | |||
| } | |||
| } | |||
| @@ -449,7 +449,7 @@ void Cloner::Run() { | |||
| // Basic and Inline Clone | |||
| FuncGraphPtrList func_graphs; | |||
| (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), | |||
| [](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); | |||
| [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); | |||
| manager_ = Manage(func_graphs, false); | |||
| CloneNodes(); | |||
| LinkEdges(); | |||
| @@ -495,13 +495,13 @@ void Cloner::CloneNodes() { | |||
| } | |||
| void Cloner::LinkEdges() { | |||
| for (auto& node_pair : nodes_) { | |||
| for (auto &node_pair : nodes_) { | |||
| CNodePtr old_node = node_pair.first; | |||
| CNodePtr new_node = node_pair.second; | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| for (auto& input : old_node->inputs()) { | |||
| auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; | |||
| for (auto &input : old_node->inputs()) { | |||
| auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; | |||
| new_node->add_input(new_input); | |||
| } | |||
| } | |||
| @@ -509,10 +509,10 @@ void Cloner::LinkEdges() { | |||
| // For the graphs cloned, update its default value map to the cloned nodes | |||
| void Cloner::SetDefaults() { | |||
| for (auto& item : graph_set_) { | |||
| for (auto &item : graph_set_) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| if (repl_func_graph_.count(item) != 0) { | |||
| for (auto& param_def : item->parameter_default_value()) { | |||
| for (auto ¶m_def : item->parameter_default_value()) { | |||
| MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); | |||
| if (repl_node_.count(param_def.second) != 0) { | |||
| repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); | |||
| @@ -524,7 +524,7 @@ void Cloner::SetDefaults() { | |||
| } | |||
| } | |||
| AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { | |||
| AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; | |||
| @@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { | |||
| MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; | |||
| } | |||
| AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { | |||
| AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { | |||
| #ifdef ENABLE_PROFILE | |||
| double time = GetTime(); | |||
| #endif | |||
| @@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { | |||
| return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); | |||
| } | |||
| FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { | |||
| FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { | |||
| #ifdef ENABLE_PROFILE | |||
| double time = GetTime(); | |||
| #endif | |||
| @@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { | |||
| return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); | |||
| } | |||
| FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { | |||
| FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr); | |||
| return cloner[func_graph]; | |||
| } | |||
| AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||
| const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { | |||
| AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||
| const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(target_func_graph); | |||
| Cloner cloner({}, false); | |||
| @@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe | |||
| return cloner[func_graph->output()]; | |||
| } | |||
| FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { | |||
| FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| Cloner cloner({}, false); | |||
| cloner.AddClone(func_graph, nullptr, {}, kLifting); | |||
| return cloner[func_graph]; | |||
| } | |||
| ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { | |||
| ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphPtrList func_graphs = {func_graph}; | |||
| ClonerPtr cloner = | |||
| @@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r | |||
| return cloner; | |||
| } | |||
| FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { | |||
| FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| TraceManager::DebugTrace(func_graph->debug_info(), relation); | |||
| auto new_func_graph = std::make_shared<FuncGraph>(); | |||
| TraceManager::EndTrace(); | |||
| auto& parameters = func_graph->parameters(); | |||
| (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void { | |||
| auto ¶meters = func_graph->parameters(); | |||
| (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info())); | |||
| (void)new_func_graph->add_parameter(); | |||
| @@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP | |||
| new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); | |||
| new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | |||
| new_func_graph->set_is_generate(func_graph->is_generated()); | |||
| for (auto& item : func_graph->parameter_default_value()) { | |||
| for (auto &item : func_graph->parameter_default_value()) { | |||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | |||
| } | |||
| @@ -43,26 +43,26 @@ struct CloneInfo { | |||
| class Cloner { | |||
| public: | |||
| explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, | |||
| explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false, | |||
| bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, | |||
| const TraceInfoPtr& relation = std::make_shared<TraceCopy>(), | |||
| const TraceInfoPtr& target_relation = nullptr); | |||
| const TraceInfoPtr &relation = std::make_shared<TraceCopy>(), | |||
| const TraceInfoPtr &target_relation = nullptr); | |||
| ~Cloner() = default; | |||
| void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, | |||
| const AnfNodePtrList& params = {}, CloneType type = kBasic); | |||
| void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, | |||
| const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); | |||
| void Run(); | |||
| // Interfaces for specializer | |||
| AnfNodePtr CloneDisconnected(const AnfNodePtr& root); | |||
| AnfNodePtr operator[](const AnfNodePtr& node); | |||
| FuncGraphPtr operator[](const FuncGraphPtr& func_graph); | |||
| AnfNodePtr CloneDisconnected(const AnfNodePtr &root); | |||
| AnfNodePtr operator[](const AnfNodePtr &node); | |||
| FuncGraphPtr operator[](const FuncGraphPtr &func_graph); | |||
| // Map of replicate nodes and graphs | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; } | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; } | |||
| std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; } | |||
| // Scope of cloned graphs | |||
| void set_scope(const ScopePtr& scope) { scope_ = scope; } | |||
| void set_scope(const ScopePtr &scope) { scope_ = scope; } | |||
| const ScopePtr scope() const { return scope_; } | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; | |||
| @@ -71,31 +71,31 @@ class Cloner { | |||
| void CloneNodes(); | |||
| void LinkEdges(); | |||
| void SetDefaults(); | |||
| void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||
| void CloneValueNode(const AnfNodePtr& node); | |||
| void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||
| void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||
| void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); | |||
| void CloneValueNodes(const FuncGraphPtr& func_graph); | |||
| void AddChildGraphs(const FuncGraphPtr& func_graph); | |||
| void AddTotalGraphs(const FuncGraphPtr& func_graph); | |||
| bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); | |||
| void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||
| void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||
| void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||
| void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); | |||
| void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); | |||
| void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||
| void GenParameters(const FuncGraphPtr& func_graph); | |||
| void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); | |||
| ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); | |||
| void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, | |||
| AnfNodePtrList* const input_params); | |||
| void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); | |||
| void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); | |||
| void SetEdges(const FuncGraphPtr& func_graph); | |||
| void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtrList& params); | |||
| void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneValueNode(const AnfNodePtr &node); | |||
| void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||
| void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); | |||
| void CloneValueNodes(const FuncGraphPtr &func_graph); | |||
| void AddChildGraphs(const FuncGraphPtr &func_graph); | |||
| void AddTotalGraphs(const FuncGraphPtr &func_graph); | |||
| bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); | |||
| void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||
| void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||
| void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||
| void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); | |||
| void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); | |||
| void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||
| void GenParameters(const FuncGraphPtr &func_graph); | |||
| void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node); | |||
| ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); | |||
| void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params, | |||
| AnfNodePtrList *const input_params); | |||
| void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); | |||
| void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs); | |||
| void SetEdges(const FuncGraphPtr &func_graph); | |||
| void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtrList ¶ms); | |||
| void Lift(); | |||
| void LiftParameters(); | |||
| @@ -118,17 +118,17 @@ class Cloner { | |||
| std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; | |||
| }; | |||
| FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); | |||
| FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); | |||
| AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||
| const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); | |||
| AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||
| const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); | |||
| FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); | |||
| FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); | |||
| ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); | |||
| ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); | |||
| FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, | |||
| const TraceInfoPtr& relation = std::make_shared<TraceTransform>()); | |||
| FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, | |||
| const TraceInfoPtr &relation = std::make_shared<TraceTransform>()); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ | |||
| @@ -27,17 +27,17 @@ | |||
| namespace mindspore { | |||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { | |||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) { | |||
| auto m = std::make_shared<FuncGraphManager>(func_graphs, manage); | |||
| m->Init(); | |||
| return m; | |||
| } | |||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { | |||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) { | |||
| FuncGraphManagerPtr m = nullptr; | |||
| bool root = false; | |||
| for (auto& fg : func_graphs) { | |||
| for (auto &fg : func_graphs) { | |||
| if (fg == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool ma | |||
| root = true; | |||
| } | |||
| for (auto& fg : func_graphs) { | |||
| for (auto &fg : func_graphs) { | |||
| if (fg == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { | |||
| return Manage(func_graphs, manage); | |||
| } | |||
| FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage) | |||
| FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage) | |||
| : roots_(roots), is_manage_(manage) { | |||
| Reset(); | |||
| } | |||
| @@ -103,12 +103,12 @@ void FuncGraphManager::Init() { | |||
| auto roots = roots_; | |||
| roots_ = FuncGraphSet(); | |||
| for (auto& fg : roots) { | |||
| for (auto &fg : roots) { | |||
| AddFuncGraph(fg, true); | |||
| } | |||
| } | |||
| FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { | |||
| FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); | |||
| func_graph_parents_total_->Recompute(fg); | |||
| @@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) | |||
| return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; | |||
| } | |||
| FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { | |||
| FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| MS_EXCEPTION_IF_NULL(func_graph_parent_); | |||
| MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); | |||
| @@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { | |||
| return func_graph_parent_->parent_analysis()[fg]; | |||
| } | |||
| FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { | |||
| FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| MS_EXCEPTION_IF_NULL(children_); | |||
| MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); | |||
| @@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { | |||
| return children_->children_analysis()[fg]; | |||
| } | |||
| FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { | |||
| FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| MS_EXCEPTION_IF_NULL(scopes_); | |||
| MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); | |||
| @@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { | |||
| return scopes_->scope_analysis()[fg]; | |||
| } | |||
| FVTotalMap& FuncGraphManager::free_variables_total() const { | |||
| FVTotalMap &FuncGraphManager::free_variables_total() const { | |||
| MS_EXCEPTION_IF_NULL(free_variables_total_); | |||
| free_variables_total_->Recompute(); | |||
| return free_variables_total_->fv_total_analysis(); | |||
| } | |||
| FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { | |||
| FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(func_graphs_used_total_); | |||
| func_graphs_used_total_->Recompute(fg); | |||
| return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; | |||
| } | |||
| bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { | |||
| bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| recursive_->Recompute(fg); | |||
| if (recursive_->recursive_analysis().count(fg) == 0) { | |||
| @@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { | |||
| return recursive_->recursive_analysis()[fg]; | |||
| } | |||
| std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { | |||
| std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| if (recursive(fg)) { | |||
| if (!recursive_->recursive_map().count(fg)) { | |||
| @@ -185,7 +185,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons | |||
| } | |||
| } | |||
| bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { | |||
| bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { | |||
| MS_EXCEPTION_IF_NULL(j_total_); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| j_total_->Recompute(fg); | |||
| @@ -225,10 +225,10 @@ void FuncGraphManager::Clear() { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||
| void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) { | |||
| MS_LOG(DEBUG) << "Start keep roots"; | |||
| bool root_exist = false; | |||
| for (auto& item : func_graphs) { | |||
| for (auto &item : func_graphs) { | |||
| if (roots_.contains(item)) { | |||
| root_exist = true; | |||
| break; | |||
| @@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||
| roots = roots_; | |||
| } else { | |||
| roots_.clear(); | |||
| for (auto& item : roots) { | |||
| for (auto &item : roots) { | |||
| AddFuncGraph(item, true); | |||
| } | |||
| } | |||
| FuncGraphSet keep; | |||
| for (auto& item : roots) { | |||
| for (auto &item : roots) { | |||
| MS_LOG(DEBUG) << "roots: " << item->ToString(); | |||
| keep.update(func_graphs_used_total(item)); | |||
| #ifdef DEBUG | |||
| for (auto& k : keep) { | |||
| for (auto &k : keep) { | |||
| MS_LOG(DEBUG) << "keep: " << k->ToString(); | |||
| } | |||
| #endif | |||
| @@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||
| } else { | |||
| Clear(); | |||
| FuncGraphSet roots(func_graphs); | |||
| for (auto& item : roots) { | |||
| for (auto &item : roots) { | |||
| AddFuncGraph(item, true); | |||
| } | |||
| } | |||
| @@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() { | |||
| MaybeDropFuncGraphs(func_graphs_, true); | |||
| } | |||
| void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { | |||
| void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| if (is_manage_) { | |||
| if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { | |||
| @@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { | |||
| func_graphs_.add(fg); | |||
| } | |||
| void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { | |||
| void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { | |||
| FuncGraphSet todo(func_graphs); | |||
| std::set<FuncGraphPtr> dropped; | |||
| // int count = 0; | |||
| @@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph_users_); | |||
| auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||
| auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||
| if (!users.empty() && !ignore_users) { | |||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | |||
| continue; | |||
| @@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool | |||
| todo.update(MaybeDropNodes(return_vec)); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| for (auto& fg : dropped) { | |||
| for (auto &fg : dropped) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| signals_->DropFuncGraph(fg); | |||
| all_nodes_.difference_update(fg->parameters()); | |||
| @@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||
| MS_EXCEPTION_IF_NULL(inp); | |||
| if (direction == kDecEdge) { | |||
| MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | |||
| auto& users_node = node_users_[inp]; | |||
| auto &users_node = node_users_[inp]; | |||
| if (!users_node.contains(make_pair(node, index))) { | |||
| return; | |||
| } | |||
| @@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||
| MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); | |||
| AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); | |||
| } | |||
| auto& users_node = node_users_[inp]; | |||
| auto &users_node = node_users_[inp]; | |||
| users_node.add(make_pair(node, index)); | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| signals_->AddEdge(node, index, inp); | |||
| } | |||
| } | |||
| void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { | |||
| void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| int index = 0; | |||
| for (auto& inp : cnode->inputs()) { | |||
| for (auto &inp : cnode->inputs()) { | |||
| ProcessEdge(cnode, index, inp, direction); | |||
| ++index; | |||
| } | |||
| } | |||
| } | |||
| IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { | |||
| IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { | |||
| if (all_nodes_.contains(node)) { | |||
| return EXCLUDE; | |||
| } else { | |||
| @@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { | |||
| } | |||
| } | |||
| void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||
| void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) { | |||
| AnfNodeSet acq; | |||
| for (auto& node : nodes) { | |||
| for (auto &node : nodes) { | |||
| std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); | |||
| AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); | |||
| @@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||
| acq.update(new_nodes); | |||
| } | |||
| for (auto& node : acq) { | |||
| for (auto &node : acq) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| FuncGraphPtr fg = node->func_graph(); | |||
| if (fg != nullptr) { | |||
| @@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||
| } | |||
| } | |||
| FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& nodes) { | |||
| FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { | |||
| AnfNodeSet nodes_ordered(nodes); | |||
| FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | |||
| MS_EXCEPTION_IF_NULL(signals_); | |||
| @@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& | |||
| if (!all_nodes_.contains(node)) { | |||
| continue; | |||
| } | |||
| AnfNodeIndexSet& users = node_users_[node]; | |||
| AnfNodeIndexSet &users = node_users_[node]; | |||
| std::vector<AnfNodePtr> parameters; | |||
| if (!users.empty() || | |||
| @@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& | |||
| return func_graphs_to_check; | |||
| } | |||
| void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters) { | |||
| void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters) { | |||
| auto tr = Transact(); | |||
| tr.SetParameters(fg, parameters); | |||
| tr.Commit(); | |||
| } | |||
| bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { | |||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| auto tr = Transact(); | |||
| bool success = tr.Replace(old_node, new_node); | |||
| if (success) { | |||
| @@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new | |||
| return success; | |||
| } | |||
| void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { | |||
| void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { | |||
| auto tr = Transact(); | |||
| tr.SetEdge(node, index, value); | |||
| tr.Commit(); | |||
| } | |||
| void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { | |||
| void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { | |||
| AnfNodePtr source_return = source->get_return(); | |||
| AnfNodePtr source_output = source->output(); | |||
| AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0); | |||
| @@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||
| (void)all_nodes_.erase(source_return); | |||
| (void)node_users_.erase(source_return); | |||
| signals_->DropNode(source_return); | |||
| for (auto& node : source->nodes()) { | |||
| for (auto &node : source->nodes()) { | |||
| node->set_func_graph(target); | |||
| if (node->scope() == kDefaultScope) { | |||
| node->set_scope(scope); | |||
| } | |||
| } | |||
| for (auto& used : source->func_graphs_used()) { | |||
| for (auto &used : source->func_graphs_used()) { | |||
| (void)func_graph_users_->Inc(used.first, target, used.second); | |||
| (void)this->func_graph_users()[used.first].erase(source); | |||
| } | |||
| for (auto& child : this->func_graph_child_direct()[source]) { | |||
| for (auto &child : this->func_graph_child_direct()[source]) { | |||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | |||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | |||
| } | |||
| for (auto& fv_count : this->free_variables_direct()[source]) { | |||
| for (auto &fv_count : this->free_variables_direct()[source]) { | |||
| auto fv_g = fv_count.first->func_graph(); | |||
| auto& count_on_g = this->func_graph_child_direct()[fv_g]; | |||
| auto &count_on_g = this->func_graph_child_direct()[fv_g]; | |||
| auto pair = count_on_g.find(source); | |||
| if (fv_g != target && pair != count_on_g.end()) { | |||
| (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); | |||
| @@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() { | |||
| return tr; | |||
| } | |||
| void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, | |||
| EdgeTupleCounter* rm_edges, Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms) { | |||
| for (auto& iter : changes) { | |||
| void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, | |||
| EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) { | |||
| for (auto &iter : changes) { | |||
| auto operation = iter.op; | |||
| auto args = iter.args; | |||
| if (operation == Change::kTxSetEdge) { | |||
| @@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl | |||
| auto param = args.cast<ArgsOfSetParams>(); | |||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||
| auto old_parameters = param.func_graph->parameters(); | |||
| for (auto& p : param.params) { | |||
| for (auto &p : param.params) { | |||
| (*adds)[p] += 1; | |||
| } | |||
| for (auto& p : old_parameters) { | |||
| for (auto &p : old_parameters) { | |||
| (*rms)[p] += 1; | |||
| } | |||
| param.func_graph->set_parameters(param.params); | |||
| @@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl | |||
| } | |||
| } | |||
| void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||
| void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) { | |||
| EdgeTupleCounter add_edges; | |||
| EdgeTupleCounter rm_edges; | |||
| Counter<AnfNodePtr> adds; | |||
| @@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||
| ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); | |||
| auto sub_edges = add_edges - rm_edges; | |||
| for (auto& iter : sub_edges) { | |||
| for (auto &iter : sub_edges) { | |||
| auto root_node = iter.first.first; | |||
| int index = iter.first.second.first; | |||
| auto new_node = iter.first.second.second; | |||
| @@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||
| auto sub_nodes = adds - rms; | |||
| std::vector<AnfNodePtr> nodes; | |||
| (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), | |||
| [](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); | |||
| [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; }); | |||
| AcquireNodes(nodes); | |||
| auto sub_edges_reverse = rm_edges - add_edges; | |||
| for (auto& iter : sub_edges_reverse) { | |||
| for (auto &iter : sub_edges_reverse) { | |||
| auto root_node = iter.first.first; | |||
| int index = iter.first.second.first; | |||
| auto old_node = iter.first.second.second; | |||
| @@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||
| std::vector<AnfNodePtr> nodes_reverse; | |||
| (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), | |||
| [](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); | |||
| [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; }); | |||
| auto drop_func_graphs = MaybeDropNodes(nodes_reverse); | |||
| MaybeDropFuncGraphs(*drop_func_graphs); | |||
| } | |||
| void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params) { | |||
| void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) { | |||
| changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | |||
| } | |||
| bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { | |||
| bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| FuncGraphPtr old_func_graph = old_node->func_graph(); | |||
| @@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& | |||
| return false; | |||
| } | |||
| auto users = manager_->node_users()[old_node]; | |||
| for (auto& node : users) { | |||
| for (auto &node : users) { | |||
| SetEdge(node.first, node.second, new_node); | |||
| } | |||
| return true; | |||
| } | |||
| void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { | |||
| void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { | |||
| if (k < 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid value k = " << k; | |||
| } | |||
| @@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() { | |||
| manager_->CommitChanges(changes); | |||
| } | |||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) | |||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) | |||
| : manager_(manager), include_func_graph_none_(false) { | |||
| manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); | |||
| manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); | |||
| @@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) | |||
| manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); | |||
| } | |||
| NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { | |||
| NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { | |||
| include_func_graph_none_ = true; | |||
| nodes_analysis_[nullptr] = AnfNodeSet(); | |||
| @@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) { | |||
| void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| // change the owner of node except for the src's return node | |||
| for (auto& it : nodes_analysis_[src]) { | |||
| for (auto &it : nodes_analysis_[src]) { | |||
| nodes_analysis_[dst].add(it); | |||
| } | |||
| (void)nodes_analysis_.erase(src); | |||
| @@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | |||
| DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { | |||
| DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); | |||
| } | |||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | |||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { | |||
| auto& d = count_nodes_map_[func_graph]; | |||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| @@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP | |||
| return false; | |||
| } | |||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { | |||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto& d = count_nodes_map_[func_graph]; | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| @@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP | |||
| return false; | |||
| } | |||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { | |||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| @@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP | |||
| } | |||
| } | |||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { | |||
| auto& d = count_func_graphs_map_[func_graph]; | |||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| @@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { | |||
| auto& d = count_func_graphs_map_[func_graph]; | |||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| @@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { | |||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| @@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr | |||
| } | |||
| void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_nodes_map_[src]) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_nodes_map_.erase(src); | |||
| @@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed | |||
| } | |||
| void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_nodes_map_[src]) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_nodes_map_.erase(src); | |||
| @@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc | |||
| } | |||
| void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_nodes_map_[src]) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| FuncGraphPtr fg2 = it.first->func_graph(); | |||
| if (fg2 != dst) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| @@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| (void)count_nodes_map_.erase(src); | |||
| } | |||
| static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { | |||
| static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { | |||
| FuncGraphPtr gn = std::make_shared<FuncGraph>(); | |||
| (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); | |||
| return gn; | |||
| @@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP | |||
| } | |||
| void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_func_graphs_map_[src]) { | |||
| for (auto &it : count_func_graphs_map_[src]) { | |||
| FuncGraphPtr fg = it.first; | |||
| if (fg != dst) { | |||
| (void)Inc(dst, fg, it.second); | |||
| @@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr | |||
| } | |||
| void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_func_graphs_map_[src]) { | |||
| for (auto &it : count_func_graphs_map_[src]) { | |||
| if (it.first != dst) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| @@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed | |||
| void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| // all graph use in src need to change to dst, so meger the to dst use | |||
| for (auto& it : count_func_graphs_map_[src]) { | |||
| for (auto &it : count_func_graphs_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_func_graphs_map_[dst].erase(src); | |||
| @@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp | |||
| } | |||
| void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| for (auto& it : count_nodes_map_[src]) { | |||
| for (auto &it : count_nodes_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_nodes_map_.erase(src); | |||
| @@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, | |||
| void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||
| // all graph use in src need to change to dst, so meger the to dst use | |||
| for (auto& it : count_func_graphs_map_[src]) { | |||
| for (auto &it : count_func_graphs_map_[src]) { | |||
| (void)Inc(dst, it.first, it.second); | |||
| } | |||
| (void)count_func_graphs_map_.erase(src); | |||
| } | |||
| DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { | |||
| DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | |||
| validate_ = false; | |||
| @@ -914,20 +914,20 @@ void DepComputer::Recompute() { | |||
| } | |||
| } | |||
| void DepComputer::Recompute(const FuncGraphPtr& fg) { | |||
| void DepComputer::Recompute(const FuncGraphPtr &fg) { | |||
| if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { | |||
| RealRecompute(fg); | |||
| func_graphs_validate_[fg] = true; | |||
| } | |||
| } | |||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { | |||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||
| if (path == nullptr || path->contains(fg)) { | |||
| return std::make_shared<FuncGraphSet>(); | |||
| } | |||
| FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | |||
| FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; | |||
| for (auto& dep : deps[fg]) { | |||
| FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; | |||
| for (auto &dep : deps[fg]) { | |||
| MS_EXCEPTION_IF_NULL(dep.first); | |||
| auto proxy = dep.first->transforms().find("proxy"); | |||
| if (proxy != dep.first->transforms().end()) { | |||
| @@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); | |||
| } | |||
| bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { | |||
| bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { | |||
| auto l1 = lhs.second.size(); | |||
| auto l2 = rhs.second.size(); | |||
| return l1 < l2; | |||
| @@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { | |||
| } else { | |||
| // return nearest parent as parent | |||
| FuncGraphSet deps_copy(deps); | |||
| for (auto& dep : deps) { | |||
| for (auto &dep : deps) { | |||
| auto parent_deps = this->manager_->func_graph_parents_total(dep); | |||
| for (auto& p_d : parent_deps) { | |||
| for (auto &p_d : parent_deps) { | |||
| if (deps_copy.count(p_d)) { | |||
| (void)deps_copy.erase(p_d); | |||
| } | |||
| @@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { | |||
| void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto used_fg_total = manager_->func_graphs_used_total(fg); | |||
| for (auto& used_fg : used_fg_total) { | |||
| for (auto &used_fg : used_fg_total) { | |||
| if (manager_->parent(used_fg) == fg) { | |||
| children_analysis_[fg].add(used_fg); | |||
| } | |||
| @@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { | |||
| void ScopeComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto& children = manager_->children(fg); | |||
| auto &children = manager_->children(fg); | |||
| scope_analysis_[fg] = FuncGraphSet(); | |||
| scope_analysis_[fg].add(fg); | |||
| for (auto& child : children) { | |||
| for (auto &child : children) { | |||
| scope_analysis_[fg].add(child); | |||
| } | |||
| } | |||
| @@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() { | |||
| auto manager = DepComputer::manager_; | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| for (auto& fg : manager->func_graphs()) { | |||
| for (auto &fg : manager->func_graphs()) { | |||
| fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); | |||
| count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); | |||
| count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); | |||
| } | |||
| for (auto& fg : manager->func_graphs()) { | |||
| for (auto &fg : manager->func_graphs()) { | |||
| AnfNodeCounterMap items = manager->free_variables_direct()[fg]; | |||
| for (auto& iter : items) { | |||
| for (auto &iter : items) { | |||
| auto curr = fg; | |||
| while (curr) { | |||
| (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | |||
| curr = manager->parent(curr); | |||
| const AnfNodeSet& nodes = manager->nodes()[curr]; | |||
| const AnfNodeSet &nodes = manager->nodes()[curr]; | |||
| if (nodes.contains(iter.first)) { | |||
| break; | |||
| } | |||
| @@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() { | |||
| } | |||
| auto items_fg = manager->func_graphs_used()[fg]; | |||
| for (auto& iter : items_fg) { | |||
| for (auto &iter : items_fg) { | |||
| auto p = manager->parent(iter.first); | |||
| if (p == nullptr) { | |||
| continue; | |||
| @@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() { | |||
| } | |||
| } | |||
| } | |||
| for (auto& fg : manager->func_graphs()) { | |||
| auto& fvp = count_nodes_map_[fg]; | |||
| auto& fvg = count_func_graphs_map_[fg]; | |||
| for (auto& item : fvp) { | |||
| for (auto &fg : manager->func_graphs()) { | |||
| auto &fvp = count_nodes_map_[fg]; | |||
| auto &fvg = count_func_graphs_map_[fg]; | |||
| for (auto &item : fvp) { | |||
| fv_total_analysis_[fg][item.first] = item.second; | |||
| } | |||
| for (auto& item : fvg) { | |||
| for (auto &item : fvg) { | |||
| fv_total_analysis_[fg][item.first] = item.second; | |||
| } | |||
| } | |||
| @@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() { | |||
| void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto& used = this->manager_->func_graphs_used(); | |||
| auto &used = this->manager_->func_graphs_used(); | |||
| std::vector<FuncGraphPtr> todo; | |||
| std::vector<FuncGraphPtr> todo_new; | |||
| todo.push_back(fg); | |||
| while (!todo.empty()) { | |||
| todo_new.clear(); | |||
| for (auto& gt : todo) { | |||
| for (auto& item : used[gt]) { | |||
| for (auto > : todo) { | |||
| for (auto &item : used[gt]) { | |||
| auto used_fg = item.first; | |||
| if (used_fg == fg) { | |||
| func_graph_used_total_analysis_[fg].add(used_fg); | |||
| @@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| } | |||
| } | |||
| bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { | |||
| bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto& used = manager->func_graphs_used(); | |||
| auto &used = manager->func_graphs_used(); | |||
| std::vector<FuncGraphPtr> todo; | |||
| std::vector<FuncGraphPtr> todo_new; | |||
| todo.push_back(fg); | |||
| FuncGraphSet used_total; | |||
| while (!todo.empty()) { | |||
| todo_new.clear(); | |||
| for (auto& gt : todo) { | |||
| for (auto& item : used[gt]) { | |||
| for (auto > : todo) { | |||
| for (auto &item : used[gt]) { | |||
| auto used_g = item.first; | |||
| if (used_g == fg) { | |||
| return true; | |||
| @@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { | |||
| this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); | |||
| } | |||
| void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace) { | |||
| void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) { | |||
| MS_EXCEPTION_IF_NULL(trace); | |||
| auto res = std::find(trace->begin(), trace->end(), fg); | |||
| // find recursive | |||
| @@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F | |||
| } | |||
| } else { | |||
| trace->push_back(fg); | |||
| auto& used_fgs = manager_->func_graphs_used()[fg]; | |||
| auto &used_fgs = manager_->func_graphs_used()[fg]; | |||
| for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { | |||
| CheckRecursiveGraphs(iter->first, trace); | |||
| } | |||
| @@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F | |||
| } | |||
| } | |||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { | |||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||
| MS_EXCEPTION_IF_NULL(path); | |||
| if (path->contains(fg)) { | |||
| MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| auto& func_graph_counter_map = manager_->func_graph_j_direct(); | |||
| auto &func_graph_counter_map = manager_->func_graph_j_direct(); | |||
| if (!func_graph_counter_map[fg].empty()) { | |||
| // check g1->J(fg)->g2->g cycle; | |||
| auto contains_j = | |||
| @@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt | |||
| path->add(fg); | |||
| // check if func graphs used contains J(func_graph); | |||
| auto& used = this->manager_->func_graphs_used(); | |||
| for (auto& item : used[fg]) { | |||
| auto &used = this->manager_->func_graphs_used(); | |||
| for (auto &item : used[fg]) { | |||
| auto used_g = item.first; | |||
| if (SeekJ(used_g, path)) { | |||
| MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() | |||
| @@ -46,13 +46,13 @@ class FuncGraphManager; | |||
| using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | |||
| struct AnfNodeIndexPairHasher { | |||
| std::size_t operator()(const std::pair<AnfNodePtr, int>& p1) const { | |||
| return std::hash<const AnfNode*>{}(p1.first.get()); | |||
| std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const { | |||
| return std::hash<const AnfNode *>{}(p1.first.get()); | |||
| } | |||
| }; | |||
| struct AnfNodeIndexPairEqual { | |||
| bool operator()(const std::pair<AnfNodePtr, int>& lhs, const std::pair<AnfNodePtr, int>& rhs) const { | |||
| bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const { | |||
| return lhs == rhs; | |||
| } | |||
| }; | |||
| @@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>; | |||
| using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>; | |||
| using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>; | |||
| struct EdgeTupleHasher { | |||
| std::size_t operator()(const EdgeTuple& p1) const { | |||
| return hash_combine({std::hash<AnfNode*>{}(p1.first.get()), std::hash<int>{}(p1.second.first), | |||
| std::hash<AnfNode*>{}(p1.second.second.get())}); | |||
| std::size_t operator()(const EdgeTuple &p1) const { | |||
| return hash_combine({std::hash<AnfNode *>{}(p1.first.get()), std::hash<int>{}(p1.second.first), | |||
| std::hash<AnfNode *>{}(p1.second.second.get())}); | |||
| } | |||
| }; | |||
| struct EdgeTupleEqual { | |||
| bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { | |||
| bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const { | |||
| return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; | |||
| } | |||
| }; | |||
| @@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter<EdgeTuple, EdgeTupleHasher, EdgeTupleEqual>; | |||
| // FuncGraphManagerPtr: return created manager | |||
| FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); | |||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage = true); | |||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true); | |||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs = {}, bool manage = true); | |||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true); | |||
| struct Signals { | |||
| Signal<void(FuncGraphPtr)> AddFuncGraph; | |||
| @@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNode | |||
| // analysis base class | |||
| class FuncGraphAnalysis { | |||
| public: | |||
| explicit FuncGraphAnalysis(const FuncGraphManager* const manager); | |||
| explicit FuncGraphAnalysis(const FuncGraphManager *const manager); | |||
| virtual ~FuncGraphAnalysis() { manager_ = nullptr; } | |||
| @@ -130,7 +130,7 @@ class FuncGraphAnalysis { | |||
| virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} | |||
| const FuncGraphManager* manager_; | |||
| const FuncGraphManager *manager_; | |||
| bool include_func_graph_none_; | |||
| }; | |||
| @@ -139,7 +139,7 @@ using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; | |||
| // graphs analysis which compute in write, read needn't recompute | |||
| class DepCollector : public FuncGraphAnalysis { | |||
| public: | |||
| explicit DepCollector(const FuncGraphManager* manager); | |||
| explicit DepCollector(const FuncGraphManager *manager); | |||
| ~DepCollector() override = default; | |||
| void Reset() { ExtraReset(); } | |||
| @@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis { | |||
| class NodesCollector final : public DepCollector { | |||
| public: | |||
| explicit NodesCollector(const FuncGraphManager* m); | |||
| explicit NodesCollector(const FuncGraphManager *m); | |||
| ~NodesCollector() override = default; | |||
| const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } | |||
| const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } | |||
| size_t size() const override { return nodes_analysis_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } | |||
| @@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector { | |||
| class CounterFuncGraphCollector : public DepCollector { | |||
| public: | |||
| explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} | |||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterFuncGraphCollector() override = default; | |||
| FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||
| // inherit from FuncGraphAnalysis | |||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||
| bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||
| bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||
| @@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector { | |||
| class CounterAnfNodeCollector : public DepCollector { | |||
| public: | |||
| explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} | |||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterAnfNodeCollector() override = default; | |||
| FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } | |||
| size_t size() const override { return count_nodes_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||
| bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||
| bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||
| bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||
| FuncGraphToAnfNodeCounterMap count_nodes_map_; | |||
| @@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector { | |||
| class ValueNodesCollector final : public CounterAnfNodeCollector { | |||
| public: | |||
| explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||
| explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~ValueNodesCollector() override = default; | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| @@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { | |||
| class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | |||
| public: | |||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~FuncGraphValueNodesCollector() override = default; | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| @@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | |||
| class FVDirectCollector final : public CounterAnfNodeCollector { | |||
| public: | |||
| explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||
| explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| ~FVDirectCollector() override = default; | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| @@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector { | |||
| class FuncGraphChildDirect final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||
| explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphChildDirect() override = default; | |||
| @@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector { | |||
| // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f | |||
| class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||
| explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| ~FuncGraphParentsDirectCollector() override = default; | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| @@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | |||
| // graph's all used graphs: key is g, value is g used graph | |||
| class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||
| explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphsUsedCollector() override = default; | |||
| @@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||
| // graph's all user graphs: key is g, value is graphs who used g | |||
| class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||
| explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphUsersCollector() override = default; | |||
| @@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | |||
| // graph's all user cnodes: key is g, value is cnodes who used g | |||
| class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | |||
| public: | |||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | |||
| ~FuncGraphUserNodesCollector() override = default; | |||
| @@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | |||
| class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||
| explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||
| void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; | |||
| ~FuncGraphJDirectCollector() override = default; | |||
| @@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; | |||
| // graphs analysis which need dynamic compute by DepCollector in each read | |||
| class DepComputer : public FuncGraphAnalysis { | |||
| public: | |||
| explicit DepComputer(const FuncGraphManager* manager); | |||
| explicit DepComputer(const FuncGraphManager *manager); | |||
| ~DepComputer() override = default; | |||
| void Reset() { | |||
| @@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis { | |||
| void Recompute(); | |||
| void Recompute(const FuncGraphPtr& fg); | |||
| void Recompute(const FuncGraphPtr &fg); | |||
| bool IsValidate() const { return validate_; } | |||
| bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } | |||
| bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } | |||
| void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } | |||
| @@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis { | |||
| // graph g's all direct or proxy parents | |||
| class FuncGraphParentsTotalComputer final : public DepComputer { | |||
| public: | |||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||
| ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } | |||
| FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | |||
| FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | |||
| size_t size() const override { return func_graph_parents_total_analysis_.size(); } | |||
| @@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer { | |||
| void RealRecompute(FuncGraphPtr fg) override; | |||
| private: | |||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>()); | |||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); | |||
| // when SeekParents calls itself recursively, it can access these variables by class member | |||
| // other than pass by formal parameters, it can save 1 parameter for SeekParents(). | |||
| FuncGraphToFuncGraphCounterMap* all_parents_direct_; | |||
| FuncGraphToFuncGraphCounterMap *all_parents_direct_; | |||
| }; | |||
| using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | |||
| @@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | |||
| // graph's nearest parent in parents total | |||
| class ParentComputer final : public DepComputer { | |||
| public: | |||
| explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~ParentComputer() override = default; | |||
| FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } | |||
| FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; } | |||
| size_t size() const override { return parent_analysis_.size(); } | |||
| @@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer { | |||
| // graph's children graph except self | |||
| class ChildrenComputer final : public DepComputer { | |||
| public: | |||
| explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~ChildrenComputer() override = default; | |||
| FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } | |||
| FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; } | |||
| size_t size() const override { return children_analysis_.size(); } | |||
| @@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer { | |||
| // graph's children graph include self | |||
| class ScopeComputer final : public DepComputer { | |||
| public: | |||
| explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~ScopeComputer() override = default; | |||
| FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } | |||
| FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; } | |||
| size_t size() const override { return scope_analysis_.size(); } | |||
| @@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash | |||
| class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { | |||
| public: | |||
| explicit FVTotalComputer(const FuncGraphManager* m) | |||
| explicit FVTotalComputer(const FuncGraphManager *m) | |||
| : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | |||
| ~FVTotalComputer() override = default; | |||
| FVTotalMap& fv_total_analysis() { return fv_total_analysis_; } | |||
| FVTotalMap &fv_total_analysis() { return fv_total_analysis_; } | |||
| size_t size() const override { return fv_total_analysis_.size(); } | |||
| @@ -453,10 +453,10 @@ class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector | |||
| class FuncGraphsUsedTotalComputer final : public DepComputer { | |||
| public: | |||
| explicit FuncGraphsUsedTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~FuncGraphsUsedTotalComputer() override = default; | |||
| FuncGraphToFuncGraphSetMap& func_graph_used_total_analysis() { return func_graph_used_total_analysis_; } | |||
| FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; } | |||
| size_t size() const override { return func_graph_used_total_analysis_.size(); } | |||
| @@ -473,13 +473,13 @@ using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGrap | |||
| class RecursiveComputer final : public DepComputer { | |||
| public: | |||
| explicit RecursiveComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~RecursiveComputer() override = default; | |||
| RecursiveMap& recursive_map() { return recursive_map_; } | |||
| FuncGraphToBoolMap& recursive_analysis() { return recursive_analysis_; } | |||
| RecursiveMap &recursive_map() { return recursive_map_; } | |||
| FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; } | |||
| void CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace); | |||
| void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace); | |||
| size_t size() const override { return recursive_analysis_.size(); } | |||
| @@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer { | |||
| class FuncGraphJTotalComputer final : public DepComputer { | |||
| public: | |||
| explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||
| explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~FuncGraphJTotalComputer() override = default; | |||
| FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } | |||
| FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; } | |||
| size_t size() const override { return j_total_analysis_.size(); } | |||
| @@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer { | |||
| void ExtraReset() override { j_total_analysis_.clear(); } | |||
| void RealRecompute(FuncGraphPtr fg) override; | |||
| bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); | |||
| bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); | |||
| }; | |||
| class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| public: | |||
| explicit FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage = true); | |||
| explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true); | |||
| ~FuncGraphManager() { | |||
| if (is_manage_) { | |||
| RemoveRoots(); | |||
| @@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| void Init(); | |||
| void Clear(); | |||
| void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); | |||
| void KeepRoots(const std::vector<FuncGraphPtr>& roots = {}); | |||
| void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); | |||
| void RemoveRoots(); | |||
| void SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters); | |||
| void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); | |||
| bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); | |||
| void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); | |||
| void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); | |||
| void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters); | |||
| void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | |||
| void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); | |||
| FuncGraphTransaction Transact(); | |||
| void CommitChanges(const std::vector<Change>& changes); | |||
| void CommitChanges(const std::vector<Change> &changes); | |||
| bool IsManaged() const { return is_manage_; } | |||
| const FuncGraphSet& roots() const { return roots_; } | |||
| const FuncGraphSet &roots() const { return roots_; } | |||
| const FuncGraphSet& func_graphs() const { return func_graphs_; } | |||
| const FuncGraphSet &func_graphs() const { return func_graphs_; } | |||
| AnfNodeSet& all_nodes() { return all_nodes_; } | |||
| AnfNodeSet &all_nodes() { return all_nodes_; } | |||
| NodeUsersMap& node_users() { return node_users_; } | |||
| NodeUsersMap &node_users() { return node_users_; } | |||
| FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } | |||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | |||
| FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||
| FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||
| FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||
| FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||
| FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { | |||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | |||
| return func_graph_child_direct_->count_func_graphs_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { | |||
| FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { | |||
| return func_graph_parents_direct_->count_func_graphs_map_; | |||
| } | |||
| FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||
| FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||
| FVTotalMap& free_variables_total() const; | |||
| FVTotalMap &free_variables_total() const; | |||
| FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; | |||
| FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; | |||
| FuncGraphSet& scopes(const FuncGraphPtr& fg) const; | |||
| FuncGraphSet &scopes(const FuncGraphPtr &fg) const; | |||
| FuncGraphPtr parent(const FuncGraphPtr& fg) const; | |||
| FuncGraphPtr parent(const FuncGraphPtr &fg) const; | |||
| FuncGraphSet& children(const FuncGraphPtr& fg) const; | |||
| FuncGraphSet &children(const FuncGraphPtr &fg) const; | |||
| FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; | |||
| FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const; | |||
| bool recursive(const FuncGraphPtr& fg) const; | |||
| std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr& fg) const; | |||
| bool recursive(const FuncGraphPtr &fg) const; | |||
| std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const; | |||
| bool func_graph_j_total(const FuncGraphPtr& fg) const; | |||
| bool func_graph_j_total(const FuncGraphPtr &fg) const; | |||
| std::shared_ptr<Signals> signals() const { return signals_; } | |||
| IncludeType Limit(const AnfNodePtr& node); | |||
| IncludeType Limit(const AnfNodePtr &node); | |||
| // Static Analysis | |||
| NodeUsersMap node_users_; | |||
| @@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| std::shared_ptr<ParentComputer> func_graph_parent_; | |||
| private: | |||
| void AddIntoManaged(const FuncGraphPtr& fg); | |||
| void AddIntoManaged(const FuncGraphPtr &fg); | |||
| void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); | |||
| void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); | |||
| void AcquireNodes(const std::vector<AnfNodePtr>& nodes); | |||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr>& nodes); | |||
| void ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, | |||
| Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms); | |||
| void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); | |||
| void AcquireNodes(const std::vector<AnfNodePtr> &nodes); | |||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); | |||
| void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, | |||
| Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); | |||
| FuncGraphSet roots_; // managed roots | |||
| FuncGraphSet func_graphs_; // managed func graphs | |||
| @@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| class FuncGraphTransaction { | |||
| public: | |||
| explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { | |||
| explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| if (!manager_->IsManaged()) { | |||
| MS_LOG(DEBUG) << "The manager is not managed yet"; | |||
| @@ -648,19 +648,19 @@ class FuncGraphTransaction { | |||
| ~FuncGraphTransaction() { manager_ = nullptr; } | |||
| // set parameters of a func graph | |||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params); | |||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms); | |||
| // replace old_node with new_node | |||
| bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| // set esge, i.e., declare setting node.inputs[key] to value. | |||
| void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); | |||
| void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); | |||
| // commit all changes | |||
| void Commit(); | |||
| private: | |||
| FuncGraphManager* manager_; | |||
| FuncGraphManager *manager_; | |||
| std::vector<Change> changes_; | |||
| }; | |||
| @@ -668,9 +668,9 @@ class FuncGraphTransaction { | |||
| struct ArgsOfSetParams { | |||
| FuncGraphPtr func_graph; | |||
| std::vector<AnfNodePtr> params; | |||
| bool operator==(const ArgsOfSetParams& other) const { return &other == this; } | |||
| bool operator==(const ArgsOfSetParams &other) const { return &other == this; } | |||
| friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { | |||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) { | |||
| os << "[ArgsOfSetParams]"; | |||
| return os; | |||
| } | |||
| @@ -681,9 +681,9 @@ struct ArgsOfSetEdge { | |||
| CNodePtr root_node; | |||
| AnfNodePtr new_node; | |||
| size_t index; | |||
| bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } | |||
| bool operator==(const ArgsOfSetEdge &other) const { return &other == this; } | |||
| friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) { | |||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) { | |||
| os << "[ArgsOfSetEdge]"; | |||
| return os; | |||
| } | |||
| @@ -693,7 +693,7 @@ struct Change { | |||
| enum OpName { kTxSetParams, kTxSetEdge }; | |||
| OpName op; | |||
| Any args; | |||
| Change(OpName name, const Any& para) : op(name), args(para) {} | |||
| Change(OpName name, const Any ¶) : op(name), args(para) {} | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -42,25 +42,25 @@ namespace mindspore { | |||
| // generate a graph corresponding to these types. | |||
| class MetaFuncGraph : public FuncGraphBase { | |||
| public: | |||
| explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } | |||
| explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } | |||
| ~MetaFuncGraph() override = default; | |||
| MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); | |||
| abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); | |||
| abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); | |||
| // Return normalized versions of the arguments. | |||
| // By default, this returns args unchanged. | |||
| virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { | |||
| virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { | |||
| return args_spec_list; | |||
| } | |||
| const std::vector<Signature>& signatures() const { return signatures_; } | |||
| void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; } | |||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||
| void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; } | |||
| // Generate a Graph for the given abstract arguments. | |||
| virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { | |||
| virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { | |||
| TypePtrList types; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), | |||
| [](const AbstractBasePtr& arg) -> TypePtr { | |||
| [](const AbstractBasePtr &arg) -> TypePtr { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| return arg->BuildType(); | |||
| }); | |||
| @@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase { | |||
| } | |||
| // Generate a Graph for this type signature. | |||
| virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { | |||
| virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { | |||
| MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; | |||
| } | |||
| @@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase { | |||
| std::string ToString() const override { return name_; } | |||
| std::size_t hash() const override { return tid(); } | |||
| virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } | |||
| bool operator==(const Value& other) const override { | |||
| virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<MetaFuncGraph>()) { | |||
| return &other == this; | |||
| } else { | |||
| @@ -31,7 +31,7 @@ namespace mindspore { | |||
| namespace tensor { | |||
| void DataBuf2Contiguous(const py::array& src, py::array* const dest) { | |||
| void DataBuf2Contiguous(const py::array &src, py::array *const dest) { | |||
| if (dest == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | |||
| } | |||
| @@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) { | |||
| // MetaTensor has default type_id_ which is TypeId::kTypeUnknown. | |||
| MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} | |||
| MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {} | |||
| MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {} | |||
| MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||
| MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (type_ptr != nullptr) { | |||
| data_type = type_ptr->type_id(); | |||
| @@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||
| } | |||
| } | |||
| MetaTensor::MetaTensor(const MetaTensor& meta_tensor) | |||
| MetaTensor::MetaTensor(const MetaTensor &meta_tensor) | |||
| : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} | |||
| MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { | |||
| MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) { | |||
| if (&meta_tensor == this) { | |||
| return *this; | |||
| } | |||
| @@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { | |||
| return *this; | |||
| } | |||
| bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { | |||
| bool MetaTensor::operator==(const MetaTensor &meta_tensor) const { | |||
| return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); | |||
| } | |||
| @@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { | |||
| return type_ptr; | |||
| } | |||
| void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { | |||
| void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) { | |||
| DeviceInfo info(format, data_type); | |||
| set_device_info(info); | |||
| } | |||
| @@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const { | |||
| return oss.str(); | |||
| } | |||
| Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||
| Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (type_ptr != nullptr) { | |||
| data_type = type_ptr->type_id(); | |||
| @@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||
| init(data_type_, shape_, &data_); | |||
| } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); } | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); } | |||
| Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } | |||
| Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); } | |||
| Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||
| Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) | |||
| Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) | |||
| : MetaTensor(tensor), device_address_(tensor.device_address()) { | |||
| init(tensor.data_, data_type); | |||
| } | |||
| Tensor& Tensor::operator=(const Tensor& tensor) { | |||
| Tensor &Tensor::operator=(const Tensor &tensor) { | |||
| if (this != &tensor) { | |||
| MetaTensor::operator=(tensor); | |||
| dirty_ = tensor.is_dirty(); | |||
| @@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { | |||
| return *this; | |||
| } | |||
| bool Tensor::operator==(const Tensor& tensor) const { | |||
| bool Tensor::operator==(const Tensor &tensor) const { | |||
| return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | |||
| } | |||
| bool Tensor::ValueEqualPy(const py::object& other) const { | |||
| bool Tensor::ValueEqualPy(const py::object &other) const { | |||
| if (!py::isinstance<Tensor>(other)) { | |||
| MS_LOG(WARNING) << "compare other not a tensor"; | |||
| return false; | |||
| @@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { | |||
| return ValueEqual(py::cast<Tensor>(other)); | |||
| } | |||
| bool Tensor::ValueEqual(const Tensor& other) const { | |||
| bool Tensor::ValueEqual(const Tensor &other) const { | |||
| auto equal = [&other, this]() -> bool { | |||
| auto np = py::module::import("numpy"); | |||
| auto equal = np.attr("equal")(data_, other.data_); | |||
| @@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); } | |||
| std::vector<int> Tensor::shape_c(void) const { return shape(); } | |||
| void* Tensor::data_c(bool writable) { | |||
| void *Tensor::data_c(bool writable) { | |||
| // operand of bit operation should be unsigned int. | |||
| unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; | |||
| bool is_c_contiguous = (flags != 0) ? true : false; | |||
| @@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) { | |||
| return data_.request(writable).ptr; | |||
| } | |||
| TypeId Tensor::GetDataType(const py::buffer_info& buf) const { | |||
| TypeId Tensor::GetDataType(const py::buffer_info &buf) const { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (buf.format.compare("e") == 0) { | |||
| data_type = TypeId::kNumberTypeFloat16; | |||
| @@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { | |||
| return data_type; | |||
| } | |||
| void Tensor::init(const py::array& input, const TypePtr& type_ptr) { | |||
| void Tensor::init(const py::array &input, const TypePtr &type_ptr) { | |||
| TypeId data_type = TypeId::kTypeUnknown; | |||
| if (type_ptr != nullptr) { | |||
| data_type = type_ptr->type_id(); | |||
| @@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { | |||
| init(input, data_type); | |||
| } | |||
| void Tensor::init(const py::array& input, const TypeId& data_type) { | |||
| void Tensor::init(const py::array &input, const TypeId &data_type) { | |||
| py::buffer_info buf = input.request(); | |||
| data_type_ = GetDataType(buf); | |||
| @@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { | |||
| } | |||
| } | |||
| void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) { | |||
| void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) { | |||
| data_type_ = data_type; | |||
| shape_ = shape; | |||
| switch (data_type) { | |||
| @@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { | |||
| return data_type_; | |||
| } | |||
| bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, | |||
| bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, | |||
| const TypeId out_data_type) { | |||
| if (out == nullptr) { | |||
| return false; | |||
| @@ -458,7 +458,7 @@ py::array Tensor::data_sync() { | |||
| return data_; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| // dtype should define before Tensor, because Tensor init depend dtype | |||
| (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") | |||
| .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) | |||
| @@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { | |||
| .def("__repr__", &Tensor::ToStringRepr) | |||
| .def("__eq__", &Tensor::ValueEqualPy) | |||
| .def(py::pickle( | |||
| [](const Tensor& t) { // __getstate__ | |||
| [](const Tensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(t.data()); | |||
| }, | |||
| [](const py::tuple& t) { // __setstate__ | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| @@ -131,16 +131,16 @@ class MetaTensor : public Value { | |||
| // information of a Tensor. The following codes will create a 2x3 float | |||
| // param data_type The data type of the tensor. | |||
| // param shape The shape of the tensor. | |||
| MetaTensor(const TypeId data_type, const std::vector<int>& shape); | |||
| MetaTensor(const TypeId data_type, const std::vector<int> &shape); | |||
| MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); | |||
| MetaTensor(const TypePtr &type_ptr, const py::tuple &shape); | |||
| // brief Constructs a MetaTensor object from an existing MetaTensor instance. | |||
| // | |||
| // The constructed MetaTensor object will have the same data type and shape as the | |||
| // meta_tensor. | |||
| // | |||
| // param meta_tensor An existing MetaTensor object. | |||
| MetaTensor(const MetaTensor& meta_tensor); | |||
| MetaTensor(const MetaTensor &meta_tensor); | |||
| ~MetaTensor() override = default; | |||
| MS_DECLARE_PARENT(MetaTensor, Value) | |||
| @@ -149,7 +149,7 @@ class MetaTensor : public Value { | |||
| // The constructed MetaTensor object has the same type and shape with meta_tensor. | |||
| // | |||
| // param meta_tensor An existing MetaTensor object. | |||
| virtual MetaTensor& operator=(const MetaTensor& meta_tensor); | |||
| virtual MetaTensor &operator=(const MetaTensor &meta_tensor); | |||
| // brief Compares two MetaTensor objects. | |||
| // | |||
| @@ -157,7 +157,7 @@ class MetaTensor : public Value { | |||
| // | |||
| // param meta_tensor The MetaTensor object to be compared. | |||
| // return true: If having same type and shape, return true, or return false. | |||
| virtual bool operator==(const MetaTensor& meta_tensor) const; | |||
| virtual bool operator==(const MetaTensor &meta_tensor) const; | |||
| // brief Returns the data type of the tensor in its MetaTensor. | |||
| // | |||
| @@ -193,7 +193,7 @@ class MetaTensor : public Value { | |||
| // | |||
| // param shape The shape of the tensor. | |||
| // return The shape's size. | |||
| size_t set_shape(const std::vector<int>& shape) { | |||
| size_t set_shape(const std::vector<int> &shape) { | |||
| this->shape_ = shape; | |||
| return shape_.size(); | |||
| } | |||
| @@ -202,9 +202,9 @@ class MetaTensor : public Value { | |||
| DeviceInfo device_info() const { return device_info_; } | |||
| // Set tensor's device info. | |||
| void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } | |||
| void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } | |||
| void SetDeviceInfo(const std::string& format, const TypePtr& data_type); | |||
| void SetDeviceInfo(const std::string &format, const TypePtr &data_type); | |||
| // Get the size of a given dimension by its index number. | |||
| int DimensionSize(size_t index) const; | |||
| @@ -222,9 +222,9 @@ class MetaTensor : public Value { | |||
| } | |||
| return hash_value; | |||
| } | |||
| bool operator==(const Value& other) const override { | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<MetaTensor>()) { | |||
| auto other_ = static_cast<const MetaTensor&>(other); | |||
| auto other_ = static_cast<const MetaTensor &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| @@ -262,49 +262,49 @@ class Tensor : public MetaTensor { | |||
| // | |||
| // param type_ptr [TypePty] Data type of the tensor. | |||
| // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. | |||
| Tensor(const TypePtr& type_ptr, const py::tuple& shape); | |||
| Tensor(const TypePtr &type_ptr, const py::tuple &shape); | |||
| // brief Constructor for C++. | |||
| // | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape The shape represented by std::vector<int> of the tensor. | |||
| Tensor(TypeId data_type, const std::vector<int>& shape); | |||
| Tensor(TypeId data_type, const std::vector<int> &shape); | |||
| // brief Constructor for Python. | |||
| // | |||
| // param input [py::array] Data value of the tensor. | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); | |||
| explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // | |||
| // param input [py::list] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); | |||
| explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // | |||
| // param input [py::tuple] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); | |||
| explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // | |||
| // param input [py::float_] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); | |||
| explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // | |||
| // param input [py::int_] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); | |||
| explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); | |||
| // brief Constructor | |||
| // | |||
| // param input [Tensor] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); | |||
| Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); | |||
| ~Tensor() override = default; | |||
| @@ -315,7 +315,7 @@ class Tensor : public MetaTensor { | |||
| // The constructed Tensor object has the same type and shape with tensor. | |||
| // | |||
| // param tensor An existing Tensor object. | |||
| Tensor& operator=(const Tensor& tensor); | |||
| Tensor &operator=(const Tensor &tensor); | |||
| // brief Compares two Tensor objects. | |||
| // | |||
| @@ -324,17 +324,17 @@ class Tensor : public MetaTensor { | |||
| // | |||
| // param tensor The Tensor object to be compared. | |||
| // return true: If having same type, shape and data, return true, or return false. | |||
| bool operator==(const Tensor& tensor) const; | |||
| bool operator==(const Tensor &tensor) const; | |||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | |||
| bool ValueEqual(const Tensor& other) const; | |||
| bool ValueEqual(const Tensor &other) const; | |||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | |||
| bool ValueEqualPy(const py::object& other) const; | |||
| bool ValueEqualPy(const py::object &other) const; | |||
| bool operator==(const Value& other) const override { | |||
| bool operator==(const Value &other) const override { | |||
| if (other.isa<Tensor>()) { | |||
| auto other_ = static_cast<const Tensor&>(other); | |||
| auto other_ = static_cast<const Tensor &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| @@ -375,13 +375,13 @@ class Tensor : public MetaTensor { | |||
| // | |||
| // param writable true if writable, false if read only | |||
| // return The pointer to the object | |||
| void* data_c(bool writable = false); | |||
| void *data_c(bool writable = false); | |||
| // brief Get data type from tensor data. | |||
| // | |||
| // param buf The buffer info of the py::array data. | |||
| // return The [TypeId] of the tensor data. | |||
| TypeId GetDataType(const py::buffer_info& buf) const; | |||
| TypeId GetDataType(const py::buffer_info &buf) const; | |||
| // brief Sets the data type of a tensor. | |||
| // | |||
| @@ -401,23 +401,23 @@ class Tensor : public MetaTensor { | |||
| // param input [py::array] the data for tensor | |||
| // param data_type [TypeId] data type | |||
| // return true if succeed, false if failed. | |||
| void init(const py::array& input, const TypeId& data_type); | |||
| void init(const py::array& input, const TypePtr& type_ptr); | |||
| void init(const py::array &input, const TypeId &data_type); | |||
| void init(const py::array &input, const TypePtr &type_ptr); | |||
| // brief init tensor attribute | |||
| // | |||
| // param data_type [TypeId] Data type of the tensor. | |||
| // param shape [py::array] The shape of the tensor. | |||
| // return true if succeed, false if failed. | |||
| void init(TypeId data_type, const std::vector<int>& shape, py::array* data); | |||
| void init(TypeId data_type, const std::vector<int> &shape, py::array *data); | |||
| bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); | |||
| bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); | |||
| public: | |||
| bool is_dirty() const { return dirty_; } | |||
| void set_dirty(const bool dirty) { dirty_ = dirty; } | |||
| DeviceAddressPtr device_address() const { return device_address_; } | |||
| void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } | |||
| void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } | |||
| py::array data_sync(); | |||
| private: | |||
| @@ -18,9 +18,9 @@ | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| bool Named::operator==(const Value& other) const { | |||
| bool Named::operator==(const Value &other) const { | |||
| if (other.isa<Named>()) { | |||
| auto other_named = static_cast<const Named&>(other); | |||
| auto other_named = static_cast<const Named &>(other); | |||
| return *this == other_named; | |||
| } else { | |||
| return false; | |||
| @@ -27,18 +27,18 @@ | |||
| namespace mindspore { | |||
| class Named : public Value { | |||
| public: | |||
| explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } | |||
| Named(const Named& other) : Value(other) { | |||
| explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } | |||
| Named(const Named &other) : Value(other) { | |||
| this->name_ = other.name_; | |||
| hash_id_ = std::hash<std::string>{}(other.name_); | |||
| } | |||
| ~Named() override = default; | |||
| MS_DECLARE_PARENT(Named, Value); | |||
| const std::string& name() const { return name_; } | |||
| virtual bool operator==(const Named& other) const { return name_ == other.name(); } | |||
| bool operator==(const Value& other) const override; | |||
| Named& operator=(const Named& other) { | |||
| const std::string &name() const { return name_; } | |||
| virtual bool operator==(const Named &other) const { return name_ == other.name(); } | |||
| bool operator==(const Value &other) const override; | |||
| Named &operator=(const Named &other) { | |||
| if (&other != this) { | |||
| this->type_ = other.type_; | |||
| this->name_ = other.name_; | |||
| @@ -50,7 +50,7 @@ class Named : public Value { | |||
| std::size_t Hash() const { return hash_id_; } | |||
| std::size_t hash() const override { return hash_id_; } | |||
| friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { | |||
| friend std::ostream &operator<<(std::ostream &os, const Named &nmd) { | |||
| os << nmd.name(); | |||
| return os; | |||
| } | |||
| @@ -31,7 +31,7 @@ | |||
| namespace mindspore { | |||
| using mindspore::abstract::AbstractFunction; | |||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { | |||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { | |||
| auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); | |||
| return prim_func; | |||
| } | |||
| @@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() { | |||
| return fn; | |||
| } | |||
| bool Primitive::operator==(const Value& other) const { | |||
| bool Primitive::operator==(const Value &other) const { | |||
| if (other.isa<Primitive>()) { | |||
| auto other_prim = static_cast<const Primitive&>(other); | |||
| auto other_prim = static_cast<const Primitive &>(other); | |||
| return *this == other_prim; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Primitive::operator==(const Primitive& other) const { | |||
| bool Primitive::operator==(const Primitive &other) const { | |||
| if (name() != other.name()) { | |||
| return false; | |||
| } | |||
| if (attrs_.size() != other.attrs_.size()) { | |||
| return false; | |||
| } | |||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr>& item) -> bool { | |||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||
| if (item.second == nullptr) { | |||
| return false; | |||
| } | |||
| @@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const { | |||
| void Primitive::set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | |||
| signatures_.clear(); | |||
| for (auto& signature : signatures) { | |||
| for (auto &signature : signatures) { | |||
| std::string name; | |||
| SignatureEnumRW rw; | |||
| SignatureEnumKind kind; | |||
| @@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const { | |||
| std::ostringstream oss; | |||
| oss << "["; | |||
| bool is_first = true; | |||
| for (auto& attr : attrs_) { | |||
| for (auto &attr : attrs_) { | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| @@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const { | |||
| } | |||
| py::function PrimitivePy::GetBpropFunction() { | |||
| static const char* const get_bprop_func_name = "get_bprop"; | |||
| static const char *const get_bprop_func_name = "get_bprop"; | |||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | |||
| py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); | |||
| return fn; | |||
| @@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() { | |||
| } | |||
| py::function PrimitivePy::GetComputeFunction() { | |||
| static const char* const compute_func_name = "vm_impl"; | |||
| static const char *const compute_func_name = "vm_impl"; | |||
| if (py::hasattr(python_obj_, compute_func_name)) { | |||
| MS_LOG(INFO) << "" << name() << " compute_func_name"; | |||
| @@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() { | |||
| return vm_fn; | |||
| } | |||
| void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { | |||
| void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| @@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { | |||
| py::dict PrimitivePy::GetAttrDict() { | |||
| py::dict attr_dict; | |||
| for (auto& attr : attrs_) { | |||
| for (auto &attr : attrs_) { | |||
| attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | |||
| } | |||
| return attr_dict; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | |||
| .value("unknown", PrimType::kPrimTypeUnknown) | |||
| .value("builtin", PrimType::kPrimTypeBuiltIn) | |||
| @@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { | |||
| .value("user_custom", PrimType::kPrimTypeUserCustom); | |||
| (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | |||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | |||
| .def(py::init<py::str&, py::object>()) | |||
| .def(py::init<py::str &, py::object>()) | |||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | |||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | |||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | |||
| @@ -48,25 +48,25 @@ enum PrimType { | |||
| class Primitive : public Named { | |||
| public: | |||
| explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) | |||
| explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) | |||
| : Named(name), signatures_(), prim_type_(prim_type) {} | |||
| Primitive(const Primitive& prim) | |||
| Primitive(const Primitive &prim) | |||
| : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||
| std::string ToString() const override { return name(); } | |||
| virtual py::function GetBpropFunction(); | |||
| virtual py::function GetComputeFunction(); | |||
| Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { | |||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| attrs_[name] = attr; | |||
| return *this; | |||
| } | |||
| Primitive& SetAttrs(const std::unordered_map<std::string, ValuePtr>& attrs) { | |||
| for (auto& attr : attrs) { | |||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| for (auto &attr : attrs) { | |||
| attrs_[attr.first] = attr.second; | |||
| } | |||
| return *this; | |||
| @@ -76,21 +76,21 @@ class Primitive : public Named { | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | |||
| signatures); | |||
| const std::vector<Signature>& signatures() const { return signatures_; } | |||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||
| void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } | |||
| void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } | |||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||
| ValuePtr GetAttr(const std::string& attrName) const { | |||
| ValuePtr GetAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||
| } | |||
| const std::unordered_map<std::string, ValuePtr>& attrs() const { return attrs_; } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||
| bool HasAttr() const { return !attrs_.empty(); } | |||
| bool HasAttr(const std::string& attrName) const { | |||
| bool HasAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return !(iter == attrs_.cend()); | |||
| } | |||
| @@ -103,8 +103,8 @@ class Primitive : public Named { | |||
| PrimType prim_type() const { return prim_type_; } | |||
| std::string instance_name() const { return instance_name_; } | |||
| std::string GetAttrsText() const; | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Primitive& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Primitive &other) const; | |||
| ~Primitive() override = default; | |||
| protected: | |||
| @@ -118,18 +118,18 @@ class Primitive : public Named { | |||
| class PrimitivePy : public Primitive { | |||
| public: | |||
| PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} | |||
| PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} | |||
| ~PrimitivePy() override = default; | |||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||
| py::function GetBpropFunction() override; | |||
| py::function GetComputeFunction() override; | |||
| void AddPyAttr(const py::str& name, const py::object& obj); | |||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||
| py::dict GetAttrDict(); | |||
| const bool parse_info_ = true; | |||
| const py::object& GetPyObj() const { return python_obj_; } | |||
| const py::object &GetPyObj() const { return python_obj_; } | |||
| bool is_tuple_input_ = false; | |||
| private: | |||
| @@ -138,13 +138,13 @@ class PrimitivePy : public Primitive { | |||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | |||
| inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { | |||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||
| os << *p; | |||
| return os; | |||
| } | |||
| struct PrimitiveEqual { | |||
| bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { | |||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->name() == t2->name(); | |||
| @@ -152,7 +152,7 @@ struct PrimitiveEqual { | |||
| }; | |||
| struct PrimitiveHasher { | |||
| std::size_t operator()(PrimitivePtr const& prim) const { | |||
| std::size_t operator()(PrimitivePtr const &prim) const { | |||
| std::size_t hash = std::hash<std::string>()(prim->name()); | |||
| return hash; | |||
| } | |||
| @@ -55,8 +55,8 @@ class BoolImm : public Scalar { | |||
| bool value() const { return v_; } | |||
| bool IsZero() override { return v_ == false; } | |||
| bool IsOne() override { return v_ == true; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const BoolImm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const BoolImm &other) const; | |||
| std::string ToString() const override { | |||
| if (v_) { | |||
| return "true"; | |||
| @@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool) | |||
| class IntergerImm : public Scalar { | |||
| public: | |||
| IntergerImm() = default; | |||
| explicit IntergerImm(const TypePtr& t) : Scalar(t) {} | |||
| explicit IntergerImm(const TypePtr &t) : Scalar(t) {} | |||
| ~IntergerImm() override = default; | |||
| MS_DECLARE_PARENT(IntergerImm, Scalar) | |||
| }; | |||
| @@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| int8_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Int8Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Int8Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| int16_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Int16Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Int16Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| int32_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Int32Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Int32Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| int64_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Int64Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Int64Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| uint8_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const UInt8Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const UInt8Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| uint16_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const UInt16Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const UInt16Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| uint32_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const UInt32Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const UInt32Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm { | |||
| bool IsZero() override { return v_ == 0; } | |||
| bool IsOne() override { return v_ == 1; } | |||
| uint64_t value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const UInt64Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const UInt64Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t); | |||
| class FloatImm : public Scalar { | |||
| public: | |||
| FloatImm() = default; | |||
| explicit FloatImm(const TypePtr& t) : Scalar(t) {} | |||
| explicit FloatImm(const TypePtr &t) : Scalar(t) {} | |||
| ~FloatImm() override = default; | |||
| MS_DECLARE_PARENT(FloatImm, Scalar) | |||
| }; | |||
| @@ -312,8 +312,8 @@ class FP32Imm : public FloatImm { | |||
| bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } | |||
| bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } | |||
| float value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const FP32Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const FP32Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -338,8 +338,8 @@ class FP64Imm : public FloatImm { | |||
| bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } | |||
| bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } | |||
| double value() const { return v_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const FP64Imm& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const FP64Imm &other) const; | |||
| std::string ToString() const override { return std::to_string(v_); } | |||
| std::string DumpText() const override { | |||
| @@ -21,8 +21,8 @@ | |||
| #include "pipeline/parse/data_converter.h" | |||
| namespace mindspore { | |||
| Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, | |||
| const py::object& arg_default, const SignatureEnumDType& arg_dtype) | |||
| Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, | |||
| const py::object &arg_default, const SignatureEnumDType &arg_dtype) | |||
| : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { | |||
| if (py::isinstance<SignatureEnumKind>(arg_default) && | |||
| py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { | |||
| @@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, | |||
| } | |||
| } | |||
| Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) | |||
| Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) | |||
| : name(arg_name), | |||
| rw(rw_tag), | |||
| kind(arg_kind), | |||
| default_value(nullptr), | |||
| dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} | |||
| REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { | |||
| (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) | |||
| .value("RW_READ", SignatureEnumRW::kRWRead) | |||
| .value("RW_WRITE", SignatureEnumRW::kRWWrite) | |||
| @@ -61,9 +61,9 @@ struct Signature { | |||
| SignatureEnumKind kind; | |||
| ValuePtr default_value; // nullptr for no default value | |||
| SignatureEnumDType dtype; | |||
| Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, | |||
| const py::object& arg_default, const SignatureEnumDType& arg_dtype); | |||
| Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind); | |||
| Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, | |||
| const py::object &arg_default, const SignatureEnumDType &arg_dtype); | |||
| Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -24,7 +24,7 @@ | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { | |||
| const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { | |||
| if (dim >= size()) { | |||
| MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; | |||
| } | |||
| @@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) { | |||
| } | |||
| } | |||
| bool BoolImm::operator==(const Value& other) const { | |||
| bool BoolImm::operator==(const Value &other) const { | |||
| if (other.isa<BoolImm>()) { | |||
| auto other_ = static_cast<const BoolImm&>(other); | |||
| auto other_ = static_cast<const BoolImm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } | |||
| bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; } | |||
| bool Int8Imm::operator==(const Value& other) const { | |||
| bool Int8Imm::operator==(const Value &other) const { | |||
| if (other.isa<Int8Imm>()) { | |||
| auto other_ = static_cast<const Int8Imm&>(other); | |||
| auto other_ = static_cast<const Int8Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } | |||
| bool Int16Imm::operator==(const Value& other) const { | |||
| bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; } | |||
| bool Int16Imm::operator==(const Value &other) const { | |||
| if (other.isa<Int16Imm>()) { | |||
| auto other_ = static_cast<const Int16Imm&>(other); | |||
| auto other_ = static_cast<const Int16Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } | |||
| bool Int32Imm::operator==(const Value& other) const { | |||
| bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; } | |||
| bool Int32Imm::operator==(const Value &other) const { | |||
| if (other.isa<Int32Imm>()) { | |||
| auto other_ = static_cast<const Int32Imm&>(other); | |||
| auto other_ = static_cast<const Int32Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } | |||
| bool Int64Imm::operator==(const Value& other) const { | |||
| bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; } | |||
| bool Int64Imm::operator==(const Value &other) const { | |||
| if (other.isa<Int64Imm>()) { | |||
| auto other_ = static_cast<const Int64Imm&>(other); | |||
| auto other_ = static_cast<const Int64Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } | |||
| bool UInt8Imm::operator==(const Value& other) const { | |||
| bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; } | |||
| bool UInt8Imm::operator==(const Value &other) const { | |||
| if (other.isa<UInt8Imm>()) { | |||
| auto other_ = static_cast<const UInt8Imm&>(other); | |||
| auto other_ = static_cast<const UInt8Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } | |||
| bool UInt16Imm::operator==(const Value& other) const { | |||
| bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; } | |||
| bool UInt16Imm::operator==(const Value &other) const { | |||
| if (other.isa<UInt16Imm>()) { | |||
| auto other_ = static_cast<const UInt16Imm&>(other); | |||
| auto other_ = static_cast<const UInt16Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } | |||
| bool UInt32Imm::operator==(const Value& other) const { | |||
| bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; } | |||
| bool UInt32Imm::operator==(const Value &other) const { | |||
| if (other.isa<UInt32Imm>()) { | |||
| auto other_ = static_cast<const UInt32Imm&>(other); | |||
| auto other_ = static_cast<const UInt32Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } | |||
| bool UInt64Imm::operator==(const Value& other) const { | |||
| bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; } | |||
| bool UInt64Imm::operator==(const Value &other) const { | |||
| if (other.isa<UInt64Imm>()) { | |||
| auto other_ = static_cast<const UInt64Imm&>(other); | |||
| auto other_ = static_cast<const UInt64Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } | |||
| bool FP32Imm::operator==(const Value& other) const { | |||
| bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; } | |||
| bool FP32Imm::operator==(const Value &other) const { | |||
| if (other.isa<FP32Imm>()) { | |||
| auto other_ = static_cast<const FP32Imm&>(other); | |||
| auto other_ = static_cast<const FP32Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } | |||
| bool FP64Imm::operator==(const Value& other) const { | |||
| bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } | |||
| bool FP64Imm::operator==(const Value &other) const { | |||
| if (other.isa<FP64Imm>()) { | |||
| auto other_ = static_cast<const FP64Imm&>(other); | |||
| auto other_ = static_cast<const FP64Imm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool ValueSequeue::operator==(const Value& other) const { | |||
| bool ValueSequeue::operator==(const Value &other) const { | |||
| if (other.isa<ValueSequeue>()) { | |||
| auto other_ = static_cast<const ValueSequeue&>(other); | |||
| auto other_ = static_cast<const ValueSequeue &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool ValueSequeue::operator==(const ValueSequeue& other) const { | |||
| bool ValueSequeue::operator==(const ValueSequeue &other) const { | |||
| if (other.elements_.size() != elements_.size()) { | |||
| return false; | |||
| } | |||
| return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), | |||
| [](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); | |||
| [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; }); | |||
| } | |||
| std::string ValueSequeue::ToString() const { | |||
| std::ostringstream buffer; | |||
| bool begin = true; | |||
| for (auto& attr : elements_) { | |||
| for (auto &attr : elements_) { | |||
| if (!begin) { | |||
| buffer << ", "; | |||
| } else { | |||
| @@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const { | |||
| return oss.str(); | |||
| } | |||
| bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } | |||
| bool StringImm::operator==(const Value& other) const { | |||
| bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } | |||
| bool StringImm::operator==(const Value &other) const { | |||
| if (other.isa<StringImm>()) { | |||
| auto other_ = static_cast<const StringImm&>(other); | |||
| auto other_ = static_cast<const StringImm &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } | |||
| bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } | |||
| bool RefKey::operator==(const Value& other) const { | |||
| bool RefKey::operator==(const Value &other) const { | |||
| if (other.isa<RefKey>()) { | |||
| auto other_ = static_cast<const RefKey&>(other); | |||
| auto other_ = static_cast<const RefKey &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } | |||
| bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } | |||
| bool AnyValue::operator==(const Value& other) const { | |||
| bool AnyValue::operator==(const Value &other) const { | |||
| if (other.isa<AnyValue>()) { | |||
| return true; | |||
| } else { | |||
| @@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstr | |||
| abstract::AbstractBasePtr ValueTuple::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| @@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() { | |||
| abstract::AbstractBasePtr ValueList::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| @@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const { | |||
| return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); | |||
| } | |||
| bool ValueSlice::operator==(const Value& other) const { | |||
| bool ValueSlice::operator==(const Value &other) const { | |||
| if (other.isa<ValueSlice>()) { | |||
| auto other_ = static_cast<const ValueSlice&>(other); | |||
| auto other_ = static_cast<const ValueSlice &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool ValueSlice::operator==(const ValueSlice& other) const { | |||
| bool ValueSlice::operator==(const ValueSlice &other) const { | |||
| MS_EXCEPTION_IF_NULL(start_); | |||
| MS_EXCEPTION_IF_NULL(stop_); | |||
| MS_EXCEPTION_IF_NULL(step_); | |||
| @@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const { | |||
| return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()}); | |||
| } | |||
| bool KeywordArg::operator==(const Value& other) const { | |||
| bool KeywordArg::operator==(const Value &other) const { | |||
| if (other.isa<KeywordArg>()) { | |||
| auto other_ = static_cast<const KeywordArg&>(other); | |||
| auto other_ = static_cast<const KeywordArg &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } | |||
| bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); } | |||
| std::string KeywordArg::ToString() const { | |||
| std::ostringstream buffer; | |||
| @@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() { | |||
| return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); | |||
| } | |||
| const ValuePtr ValueDictionary::operator[](const std::string& key) const { | |||
| const ValuePtr ValueDictionary::operator[](const std::string &key) const { | |||
| auto it = std::find_if(key_values_.begin(), key_values_.end(), | |||
| [key](const std::pair<std::string, ValuePtr>& item) { return item.first == key; }); | |||
| [key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; }); | |||
| if (it == key_values_.end()) { | |||
| MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; | |||
| } | |||
| return it->second; | |||
| } | |||
| bool ValueDictionary::operator==(const Value& other) const { | |||
| bool ValueDictionary::operator==(const Value &other) const { | |||
| if (other.isa<ValueDictionary>()) { | |||
| auto other_ = static_cast<const ValueDictionary&>(other); | |||
| auto other_ = static_cast<const ValueDictionary &>(other); | |||
| return *this == other_; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool ValueDictionary::operator==(const ValueDictionary& other) const { | |||
| bool ValueDictionary::operator==(const ValueDictionary &other) const { | |||
| if (key_values_.size() != other.key_values_.size()) { | |||
| return false; | |||
| } | |||
| @@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { | |||
| std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; | |||
| (void)std::transform( | |||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | |||
| [](const std::pair<std::string, ValuePtr>& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||
| [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||
| return std::make_shared<abstract::AbstractDictionary>(kv); | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| RefKey, ([](const py::module* m) { | |||
| RefKey, ([](const py::module *m) { | |||
| (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -35,19 +35,19 @@ | |||
| namespace mindspore { | |||
| class ValueSequeue : public Value { | |||
| public: | |||
| explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { | |||
| explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { | |||
| TypePtrList t_list; | |||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { | |||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->type(); | |||
| }); | |||
| TypePtr t = std::make_shared<Tuple>(t_list); | |||
| type_ = t; | |||
| } | |||
| ValueSequeue(const std::initializer_list<ValuePtr>& elements) : elements_(elements.begin(), elements.end()) { | |||
| ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) { | |||
| TypePtrList t_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), | |||
| [](const ValuePtr& ele) { return ele->type(); }); | |||
| [](const ValuePtr &ele) { return ele->type(); }); | |||
| TypePtr t = std::make_shared<Tuple>(t_list); | |||
| type_ = t; | |||
| } | |||
| @@ -56,10 +56,10 @@ class ValueSequeue : public Value { | |||
| std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); } | |||
| std::size_t size() const { return elements_.size(); } | |||
| bool erase(size_t idx); | |||
| const ValuePtr operator[](const std::size_t& dim) const; | |||
| const ValuePtrList& value() const { return elements_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const ValueSequeue& other) const; | |||
| const ValuePtr operator[](const std::size_t &dim) const; | |||
| const ValuePtrList &value() const { return elements_; } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const ValueSequeue &other) const; | |||
| std::string ToString() const override; | |||
| std::string DumpText() const override; | |||
| @@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr<ValueSequeue>; | |||
| class ValueTuple : public ValueSequeue { | |||
| public: | |||
| explicit ValueTuple(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} | |||
| ValueTuple(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} | |||
| explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {} | |||
| ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {} | |||
| ~ValueTuple() override = default; | |||
| MS_DECLARE_PARENT(ValueTuple, ValueSequeue) | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| @@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr<ValueTuple>; | |||
| class ValueList : public ValueSequeue { | |||
| public: | |||
| explicit ValueList(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} | |||
| ValueList(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} | |||
| explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {} | |||
| ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {} | |||
| ~ValueList() override = default; | |||
| MS_DECLARE_PARENT(ValueList, ValueSequeue) | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| @@ -94,7 +94,7 @@ class ValueList : public ValueSequeue { | |||
| }; | |||
| using ValueListPtr = std::shared_ptr<ValueList>; | |||
| inline ValuePtr MakeValue(const std::vector<ValuePtr>& v) { return std::make_shared<ValueTuple>(v); } | |||
| inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); } | |||
| inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); } | |||
| template <typename T> | |||
| @@ -103,7 +103,7 @@ template <typename T, typename A> | |||
| struct is_vector<std::vector<T, A>> : public std::true_type {}; | |||
| template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type> | |||
| ValuePtr MakeValue(const T& vec) { | |||
| ValuePtr MakeValue(const T &vec) { | |||
| std::vector<ValuePtr> list; | |||
| (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); | |||
| return std::make_shared<ValueTuple>(list); | |||
| @@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) { | |||
| class ValueSlice : public Value { | |||
| public: | |||
| ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) | |||
| ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) | |||
| : start_(start), stop_(stop), step_(step) {} | |||
| ~ValueSlice() override = default; | |||
| MS_DECLARE_PARENT(ValueSlice, Value) | |||
| std::size_t hash() const override; | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const ValueSlice& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const ValueSlice &other) const; | |||
| std::string ToString() const override; | |||
| @@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr<ValueSlice>; | |||
| class KeywordArg : public Value { | |||
| public: | |||
| KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} | |||
| KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} | |||
| ~KeywordArg() override = default; | |||
| MS_DECLARE_PARENT(KeywordArg, Value) | |||
| std::size_t hash() const override; | |||
| ValuePtr get_value() const { return value_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const KeywordArg& other) const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const KeywordArg &other) const; | |||
| std::string ToString() const override; | |||
| @@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr<KeywordArg>; | |||
| class ValueDictionary : public Value { | |||
| public: | |||
| explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>>& key_values) : key_values_(key_values) {} | |||
| explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {} | |||
| ~ValueDictionary() override = default; | |||
| MS_DECLARE_PARENT(ValueDictionary, Value) | |||
| std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); } | |||
| std::size_t size() const { return key_values_.size(); } | |||
| const ValuePtr operator[](const std::string& key) const; | |||
| const std::vector<std::pair<std::string, ValuePtr>>& value() const { return key_values_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const ValueDictionary& other) const; | |||
| const ValuePtr operator[](const std::string &key) const; | |||
| const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const ValueDictionary &other) const; | |||
| std::string ToString() const override { | |||
| std::ostringstream buffer; | |||
| std::vector<std::string> keys; | |||
| std::vector<ValuePtr> values; | |||
| for (const auto& kv : key_values_) { | |||
| for (const auto &kv : key_values_) { | |||
| keys.push_back(kv.first); | |||
| values.push_back(kv.second); | |||
| } | |||
| buffer << "(Dict: " | |||
| << " keys:("; | |||
| for (const auto& key : keys) { | |||
| for (const auto &key : keys) { | |||
| buffer << key << ", "; | |||
| } | |||
| buffer << ") values:("; | |||
| for (const auto& value : values) { | |||
| for (const auto &value : values) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| buffer << value->DumpText() << ", "; | |||
| } | |||
| @@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>; | |||
| class StringImm : public Value { | |||
| public: | |||
| explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {} | |||
| explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {} | |||
| ~StringImm() override = default; | |||
| MS_DECLARE_PARENT(StringImm, Value) | |||
| std::size_t hash() const override { return hash_; } | |||
| const std::string& value() const { return str_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const StringImm& other) const; | |||
| const std::string &value() const { return str_; } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const StringImm &other) const; | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| std::string ToString() const override { return str_; } | |||
| @@ -218,18 +218,18 @@ class StringImm : public Value { | |||
| }; | |||
| using StringImmPtr = std::shared_ptr<StringImm>; | |||
| IMM_TRAITS(StringImmPtr, std::string) | |||
| IMM_TRAITS(StringImmPtr, const char*) | |||
| IMM_TRAITS(StringImmPtr, const char *) | |||
| class RefKey : public Value { | |||
| public: | |||
| explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||
| explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||
| ~RefKey() override = default; | |||
| MS_DECLARE_PARENT(RefKey, Value) | |||
| std::size_t hash() const override { return hash_; } | |||
| const std::string& tag() const { return tag_; } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const RefKey& other) const; | |||
| const std::string &tag() const { return tag_; } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const RefKey &other) const; | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| std::string ToString() const override { return "RefKey[" + tag_ + "]"; } | |||
| @@ -251,13 +251,13 @@ class AnyValue : public Value { | |||
| ~AnyValue() override = default; | |||
| MS_DECLARE_PARENT(AnyValue, Value) | |||
| std::size_t hash() const override { return tid(); } | |||
| bool operator==(const Value& other) const override; | |||
| bool operator==(const Value &other) const override; | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| }; | |||
| extern const ValuePtr kAnyValue; | |||
| template <> | |||
| inline const char* GetValue(const ValuePtr& value) { | |||
| inline const char *GetValue(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Value is nullptr"; | |||
| } | |||
| @@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) { | |||
| template <typename T, typename S = typename std::decay<T>::type, | |||
| typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type> | |||
| std::vector<U> GetValue(const ValuePtr& value) { | |||
| std::vector<U> GetValue(const ValuePtr &value) { | |||
| if (value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Value is nullptr"; | |||
| } | |||
| @@ -280,21 +280,21 @@ std::vector<U> GetValue(const ValuePtr& value) { | |||
| << ">"; | |||
| } | |||
| std::vector<U> rets; | |||
| const std::vector<ValuePtr>& vals = value->cast<ValueSequeuePtr>()->value(); | |||
| const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value(); | |||
| (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), | |||
| [](const ValuePtr& v) { return GetValue<U>(v); }); | |||
| [](const ValuePtr &v) { return GetValue<U>(v); }); | |||
| return rets; | |||
| } | |||
| inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared<ValueNode>(t); } | |||
| inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); } | |||
| template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> | |||
| inline ValueNodePtr NewValueNode(const std::shared_ptr<T>& x) { | |||
| inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) { | |||
| return NewValueNode(MakeValue(x)); | |||
| } | |||
| template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> | |||
| inline ValueNodePtr NewValueNode(const T& x) { | |||
| inline ValueNodePtr NewValueNode(const T &x) { | |||
| return NewValueNode(MakeValue(x)); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -22,15 +22,15 @@ | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| using VisitFuncType = std::function<void(const AnfNodePtr&)>; | |||
| using VisitFuncType = std::function<void(const AnfNodePtr &)>; | |||
| class AnfVisitor { | |||
| public: | |||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&); | |||
| virtual void Visit(const AnfNodePtr&); | |||
| virtual void Visit(const CNodePtr&); | |||
| virtual void Visit(const ValueNodePtr&); | |||
| virtual void Visit(const ParameterPtr&); | |||
| VisitFuncType Match(const PrimitivePtr&, const std::vector<opt::PredicateFuncType>& = {}); | |||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &); | |||
| virtual void Visit(const AnfNodePtr &); | |||
| virtual void Visit(const CNodePtr &); | |||
| virtual void Visit(const ValueNodePtr &); | |||
| virtual void Visit(const ParameterPtr &); | |||
| VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {}); | |||
| virtual ~AnfVisitor() = default; | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -26,12 +26,12 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| void FilterInvaildKernelInfo(const CNodePtr& kernel_node, | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { | |||
| void FilterInvaildKernelInfo(const CNodePtr &kernel_node, | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list; | |||
| (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), | |||
| [&](const std::shared_ptr<kernel::KernelBuildInfo>& kernel_build_info) { | |||
| [&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) { | |||
| return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | |||
| AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | |||
| }); | |||
| @@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node, | |||
| } | |||
| } | |||
| } // namespace | |||
| void KernelQuery(const CNodePtr& kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { | |||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| TbeMetadataInfo(kernel_node, kernel_info_list); | |||
| @@ -38,11 +38,11 @@ class OpAttr { | |||
| std::string value() const { return value_; } | |||
| std::string default_value() const { return default_value_; } | |||
| void set_name(const std::string& name) { name_ = name; } | |||
| void set_param_type(const std::string& param_type) { param_type_ = param_type; } | |||
| void set_type(const std::string& type) { type_ = type; } | |||
| void set_value(const std::string& value) { value_ = value; } | |||
| void set_default_value(const std::string& default_value) { default_value_ = default_value; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } | |||
| void set_type(const std::string &type) { type_ = type; } | |||
| void set_value(const std::string &value) { value_ = value; } | |||
| void set_default_value(const std::string &default_value) { default_value_ = default_value; } | |||
| private: | |||
| std::string name_; | |||
| @@ -67,13 +67,13 @@ class OpIOInfo { | |||
| std::vector<std::string> formats() const { return formats_; } | |||
| void set_index(const int index) { index_ = index; } | |||
| void set_name(const std::string& name) { name_ = name; } | |||
| void set_name(const std::string &name) { name_ = name; } | |||
| void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } | |||
| void set_param_type(const std::string& param_type) { param_type_ = param_type; } | |||
| void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } | |||
| void set_shape(const std::string& shape) { shape_ = shape; } | |||
| void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; } | |||
| void set_formats(const std::vector<std::string>& formats) { formats_ = formats; } | |||
| void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } | |||
| void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } | |||
| void set_shape(const std::string &shape) { shape_ = shape; } | |||
| void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; } | |||
| void set_formats(const std::vector<std::string> &formats) { formats_ = formats; } | |||
| private: | |||
| int index_ = 0; | |||
| @@ -104,24 +104,24 @@ class OpInfo { | |||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | |||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | |||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | |||
| const std::unordered_map<size_t, size_t>& ref_infos() const { return ref_infos_; } | |||
| const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; } | |||
| void set_op_name(const std::string& op_name) { op_name_ = op_name; } | |||
| void set_op_name(const std::string &op_name) { op_name_ = op_name; } | |||
| void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } | |||
| void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } | |||
| void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } | |||
| void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } | |||
| void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } | |||
| void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } | |||
| void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } | |||
| void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } | |||
| void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } | |||
| void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } | |||
| void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | |||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | |||
| void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | |||
| void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | |||
| void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); } | |||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); } | |||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); } | |||
| void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& inputs) { inputs_ptr_ = inputs; } | |||
| void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& outputs) { outputs_ptr_ = outputs; } | |||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | |||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | |||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | |||
| void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; } | |||
| void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; } | |||
| bool is_ref() const { return !ref_infos_.empty(); } | |||
| bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } | |||
| void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | |||
| @@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) { | |||
| return "unknow"; | |||
| } | |||
| } | |||
| bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { | |||
| bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { | |||
| bool ret = false; | |||
| try { | |||
| auto op_json = nlohmann::json::parse(json_string); | |||
| @@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) | |||
| if (!ret) { | |||
| MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); | |||
| } | |||
| return ret; | |||
| } | |||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) { | |||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | |||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | |||
| op_info->set_binfile_name(obj.at(kBinfileName)); | |||
| op_info->set_compute_cost(obj.at(kComputeCost)); | |||
| @@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p | |||
| } | |||
| } | |||
| bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, | |||
| const std::string& impl_path) { | |||
| bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, | |||
| const std::string &impl_path) { | |||
| std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| op_info->set_op_name(obj.at(kOpName)); | |||
| @@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||
| DecodeTBESpecificInfo(obj, op_info); | |||
| } | |||
| auto attrs = obj.at(kAttr); | |||
| for (const auto& attr : attrs) { | |||
| for (const auto &attr : attrs) { | |||
| if (!DecodeAttr(attr, imply_type, op_info)) { | |||
| MS_LOG(DEBUG) << "DecodeAttr Failed"; | |||
| return false; | |||
| @@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||
| dtype_format = obj.at(kDtypeFormat); | |||
| } | |||
| auto inputs = obj.at(kIputs); | |||
| for (const auto& input : inputs) { | |||
| for (const auto &input : inputs) { | |||
| if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { | |||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | |||
| return false; | |||
| } | |||
| } | |||
| auto outputs = obj.at(kOutputs); | |||
| for (const auto& output : outputs) { | |||
| for (const auto &output : outputs) { | |||
| if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { | |||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | |||
| return false; | |||
| @@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||
| return true; | |||
| } | |||
| bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||
| const std::shared_ptr<OpInfo>& op_info) { | |||
| bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, | |||
| const std::shared_ptr<OpInfo> &op_info) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| bool ret = true; | |||
| try { | |||
| @@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||
| op_attr->set_default_value(obj.at(kDefaultValue)); | |||
| } | |||
| op_info->add_attrs_ptr(op_attr); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); | |||
| ret = false; | |||
| } | |||
| return ret; | |||
| } | |||
| bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||
| bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | |||
| size_t index) { | |||
| bool ret = true; | |||
| try { | |||
| std::vector<std::string> dtype; | |||
| std::vector<std::string> format; | |||
| for (const auto& it : dtype_format) { | |||
| for (const auto &it : dtype_format) { | |||
| dtype.emplace_back(it[index][0]); | |||
| format.emplace_back(it[index][1]); | |||
| } | |||
| op_io->set_dtypes(dtype); | |||
| op_io->set_formats(format); | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); | |||
| ret = false; | |||
| } | |||
| return ret; | |||
| } | |||
| bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | |||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) { | |||
| bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | |||
| const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) { | |||
| bool ret = true; | |||
| try { | |||
| std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); | |||
| @@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply | |||
| } else if (io_type == kOutput) { | |||
| op_info->add_outputs_ptr(op_io); | |||
| } | |||
| } catch (const std::exception& e) { | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); | |||
| ret = false; | |||
| } | |||
| return ret; | |||
| } | |||
| std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { | |||
| std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool is_gpu = (context->device_target() == kGPUDevice); | |||
| @@ -260,7 +260,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||
| << ", current op num:" << op_info_.size(); | |||
| return nullptr; | |||
| } | |||
| for (const auto& op_info : op_info_) { | |||
| for (const auto &op_info : op_info_) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { | |||
| return op_info; | |||
| @@ -271,14 +271,14 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||
| return nullptr; | |||
| } | |||
| bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) { | |||
| bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| const auto& output_infos = op_info->outputs_ptr(); | |||
| const auto& input_infos = op_info->inputs_ptr(); | |||
| const auto &output_infos = op_info->outputs_ptr(); | |||
| const auto &input_infos = op_info->inputs_ptr(); | |||
| for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { | |||
| const auto& out_name = output_infos[out_index]->name(); | |||
| const auto &out_name = output_infos[out_index]->name(); | |||
| for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { | |||
| const auto& in_name = input_infos[in_index]->name(); | |||
| const auto &in_name = input_infos[in_index]->name(); | |||
| if (out_name == in_name) { | |||
| if (op_info->has_ref_index(out_index)) { | |||
| MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; | |||
| @@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) { | |||
| return true; | |||
| } | |||
| bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo>& op_info) { | |||
| bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| for (const auto& exist_op_info : op_info_) { | |||
| for (const auto &exist_op_info : op_info_) { | |||
| MS_EXCEPTION_IF_NULL(exist_op_info); | |||
| if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && | |||
| exist_op_info->impl_path() != op_info->impl_path()) { | |||
| @@ -28,23 +28,23 @@ class OpLib { | |||
| public: | |||
| OpLib() = default; | |||
| virtual ~OpLib() = default; | |||
| bool RegOp(const std::string& json_string, const std::string& impl_path); | |||
| static std::shared_ptr<OpInfo> FindOp(const std::string& op_name, OpImplyType imply_type); | |||
| bool RegOp(const std::string &json_string, const std::string &impl_path); | |||
| static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type); | |||
| protected: | |||
| static std::vector<std::shared_ptr<OpInfo>> op_info_; | |||
| private: | |||
| static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); | |||
| static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||
| const std::shared_ptr<OpInfo>& op_info); | |||
| static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||
| static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); | |||
| static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, | |||
| const std::shared_ptr<OpInfo> &op_info); | |||
| static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | |||
| size_t index); | |||
| static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info); | |||
| static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | |||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format); | |||
| static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info); | |||
| static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info); | |||
| static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info); | |||
| static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | |||
| const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format); | |||
| static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info); | |||
| static bool CheckRepetition(const std::shared_ptr<OpInfo> &op_info); | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,6 @@ | |||
| namespace mindspore { | |||
| // cppcheck-suppress unusedFunction | |||
| std::string set_version(const std::string& version) { return version; } | |||
| std::string set_version(const std::string &version) { return version; } | |||
| } // namespace mindspore | |||
| @@ -42,11 +42,11 @@ struct OpMergedInfo { | |||
| }; | |||
| using GenAttrFuncType = | |||
| std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto*, const PrimitivePtr&)>; | |||
| std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>; | |||
| template <typename T, size_t rep_cnt = 0> | |||
| void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, | |||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||
| void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, | |||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||
| auto casted_value = dyn_cast<T>(value); | |||
| if (casted_value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; | |||
| @@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy | |||
| } | |||
| template <size_t beg_idx = 0> | |||
| void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, | |||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||
| void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, | |||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||
| auto tuple_ptr = dyn_cast<ValueTuple>(value); | |||
| if (tuple_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; | |||
| @@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib | |||
| attr_proto->set_type(attr_type); | |||
| } | |||
| void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, | |||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||
| void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, | |||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | |||
| auto attr_value = GetValue<std::string>(value); | |||
| if (attr_value == "VALID") { | |||
| @@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType | |||
| class OpAttrInfo { | |||
| public: | |||
| OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name, | |||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) | |||
| OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, | |||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) | |||
| : attr_name_(attr_name), | |||
| onnx_attr_name_(onnx_attr_name), | |||
| onnx_attr_type_(onnx_attr_type), | |||
| fn_gen_attr_(fn_gen_attr) {} | |||
| ~OpAttrInfo() {} | |||
| const std::string& attr_name() const { return attr_name_; } | |||
| const std::string& onnx_attr_name() const { return onnx_attr_name_; } | |||
| const std::string &attr_name() const { return attr_name_; } | |||
| const std::string &onnx_attr_name() const { return onnx_attr_name_; } | |||
| onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } | |||
| GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } | |||
| @@ -134,27 +134,27 @@ class OpAttrInfo { | |||
| class OpNameInfo { | |||
| public: | |||
| OpNameInfo& set_op_type(const std::string& op_type) { | |||
| OpNameInfo &set_op_type(const std::string &op_type) { | |||
| op_type_ = op_type; | |||
| return *this; | |||
| } | |||
| const std::string& op_type() const { return op_type_; } | |||
| const std::string &op_type() const { return op_type_; } | |||
| OpNameInfo& set_onnx_type(const std::string& onnx_type) { | |||
| OpNameInfo &set_onnx_type(const std::string &onnx_type) { | |||
| onnx_type_ = onnx_type; | |||
| return *this; | |||
| } | |||
| const std::string& onnx_type() const { return onnx_type_; } | |||
| const std::string &onnx_type() const { return onnx_type_; } | |||
| OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, | |||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { | |||
| OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, | |||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { | |||
| op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); | |||
| return *this; | |||
| } | |||
| const std::vector<OpAttrInfo>& op_attrs() const { return op_attrs_; } | |||
| const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; } | |||
| private: | |||
| std::string op_type_; // operator type of MindSpore | |||
| @@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE( | |||
| .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>) | |||
| .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) | |||
| .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, | |||
| [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, | |||
| const PrimitivePtr& prim) { | |||
| [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, | |||
| const PrimitivePtr &prim) { | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | |||
| auto attr_value = GetValue<std::string>(value); | |||
| if (attr_value == "valid") { | |||
| @@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, | |||
| SetAttrValueToProto<Int32Imm>) | |||
| .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, | |||
| [](ValuePtr, onnx::AttributeProto_AttributeType, | |||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | |||
| attr_proto->set_i(0); | |||
| })) | |||
| @@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( | |||
| #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name | |||
| void RegisterOpConverters(const std::function<void(OpNameInfo&&)>& fn) { | |||
| void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { | |||
| fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); | |||
| fn(OP_CONVERT_FUNCTION_NAME(Mul)()); | |||
| @@ -265,16 +265,16 @@ class OpConvertRegistry { | |||
| public: | |||
| ~OpConvertRegistry() { Clear(); } | |||
| static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } | |||
| static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } | |||
| static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } | |||
| static OpConvertRegistry& GetSingleton() { | |||
| static OpConvertRegistry &GetSingleton() { | |||
| static OpConvertRegistry registry = OpConvertRegistry(); | |||
| return registry; | |||
| } | |||
| static const std::unordered_map<std::string, OpNameInfo>& GetOpConvertMap() { return GetSingleton().op_map_; } | |||
| static const std::unordered_map<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; } | |||
| void Clear() noexcept { op_map_.clear(); } | |||
| @@ -289,59 +289,59 @@ class OnnxExporter { | |||
| OnnxExporter() {} | |||
| ~OnnxExporter() {} | |||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||
| private: | |||
| void InitModelInfo(); | |||
| void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); | |||
| void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); | |||
| void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); | |||
| void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); | |||
| size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, | |||
| onnx::GraphProto* graph_proto); | |||
| size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs, | |||
| onnx::GraphProto *graph_proto); | |||
| static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | |||
| void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); | |||
| void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); | |||
| void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr); | |||
| void ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||
| void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||
| void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||
| void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* graph_proto); | |||
| std::string GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* const graph_proto); | |||
| void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); | |||
| void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); | |||
| void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); | |||
| void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); | |||
| void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr); | |||
| void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||
| void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto); | |||
| void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); | |||
| void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); | |||
| size_t AllocateNodeIndex() { return ++onnx_node_index_; } | |||
| void ResetNodeIndex() { onnx_node_index_ = 0; } | |||
| static int GetInt32Value(const AnfNodePtr& node) { | |||
| static int GetInt32Value(const AnfNodePtr &node) { | |||
| auto value_node_ptr = dyn_cast<ValueNode>(node); | |||
| MS_EXCEPTION_IF_NULL(value_node_ptr); | |||
| return GetValue<int>(value_node_ptr->value()); | |||
| @@ -352,7 +352,7 @@ class OnnxExporter { | |||
| size_t onnx_node_index_ = 0; | |||
| }; | |||
| std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { | |||
| if (func_graph == nullptr) { | |||
| return ""; | |||
| } | |||
| @@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||
| OpConvertRegistry::GetSingleton().Clear(); | |||
| OpConvertRegistry::RegisterAllOpConverters(); | |||
| InitModelInfo(); | |||
| onnx::GraphProto* graph_proto = model_.mutable_graph(); | |||
| onnx::GraphProto *graph_proto = model_.mutable_graph(); | |||
| ExportFuncGraph(func_graph, graph_proto); | |||
| return model_.SerializeAsString(); | |||
| } | |||
| @@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() { | |||
| model_.set_ir_version(onnx::IR_VERSION_2019_1_22); | |||
| model_.set_producer_name("MindSpore"); | |||
| model_.set_producer_version("1.0"); | |||
| onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); | |||
| onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); | |||
| opset_proto->set_version(9); | |||
| } | |||
| void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||
| std::map<AnfNodePtr, size_t> node_map; | |||
| onnx_node_index_ = func_graph->parameters().size(); | |||
| @@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr | |||
| ExportNodes(func_graph, &node_map, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { | |||
| for (auto& param : func_graph->parameters()) { | |||
| void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||
| for (auto ¶m : func_graph->parameters()) { | |||
| const ParameterPtr param_ptr = dyn_cast<Parameter>(param); | |||
| if (param_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; | |||
| } | |||
| onnx::ValueInfoProto* input_proto = graph_proto->add_input(); | |||
| onnx::ValueInfoProto *input_proto = graph_proto->add_input(); | |||
| input_proto->set_name(param_ptr->ToString()); | |||
| SetValueInfoType(param_ptr, input_proto); | |||
| @@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP | |||
| continue; | |||
| } | |||
| // parameter with default value is an ONNX initializer | |||
| onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); | |||
| onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); | |||
| initializer_proto->set_name(param_ptr->ToString()); | |||
| SetTensorProtoInfo(param_ptr, initializer_proto); | |||
| // set value for initializer | |||
| @@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { | |||
| return iter->second; | |||
| } | |||
| void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { | |||
| void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { | |||
| auto dtype = node->Type(); | |||
| auto shape = node->Shape(); | |||
| onnx::TypeProto* type_proto = value_proto->mutable_type(); | |||
| onnx::TypeProto *type_proto = value_proto->mutable_type(); | |||
| if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) { | |||
| auto tensor = dyn_cast<TensorType>(dtype); | |||
| auto elem_type = tensor->element(); | |||
| const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| const auto &dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 | |||
| auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); | |||
| type_proto->mutable_tensor_type()->set_elem_type(type); | |||
| for (const auto& dim : dims) { | |||
| for (const auto &dim : dims) { | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||
| } | |||
| } | |||
| } | |||
| void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { | |||
| void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { | |||
| auto dtype = param->Type(); | |||
| auto shape = param->Shape(); | |||
| if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | |||
| @@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro | |||
| auto tensor = dyn_cast<TensorType>(dtype); | |||
| auto elem_type = tensor->element(); | |||
| const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| const auto &dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||
| tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); | |||
| for (const auto& dim : dims) { | |||
| for (const auto &dim : dims) { | |||
| tensor_proto->add_dims(dim); | |||
| } | |||
| } | |||
| void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr) { | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo>& op_merged_infos = *op_merged_infos_ptr; | |||
| void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) { | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr; | |||
| for (auto& node : nodes) { | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| @@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto | |||
| // if the key `input` does not exist, just create a new one | |||
| op_merged_infos[cnode].referred_count += 1; | |||
| } | |||
| for (auto& input : cnode->inputs()) { | |||
| for (auto &input : cnode->inputs()) { | |||
| if (!input->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| @@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto | |||
| * | +-- Parameter | |||
| * | `-- ValueNode | |||
| */ | |||
| void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto) { | |||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos; | |||
| MatchAndMark(func_graph, nodes, &op_merged_infos); | |||
| for (const AnfNodePtr& node : nodes) { | |||
| for (const AnfNodePtr &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| @@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodeP | |||
| } | |||
| } | |||
| void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | |||
| auto input_shape = node->input(2); | |||
| std::string name_shape; | |||
| if (input_shape->isa<ValueNode>()) { | |||
| auto const_node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[input_shape] = const_node_idx; | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| name_shape = std::to_string(const_node_idx); | |||
| node_proto->add_output(name_shape); | |||
| node_proto->set_op_type("Constant"); | |||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name("value"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| @@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type(prim::kPrimReshape->name()); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| node_proto->add_input(name_x); | |||
| node_proto->add_input(name_shape); | |||
| } | |||
| void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto) { | |||
| auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | |||
| auto input_axis = node->input(2); | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type(prim::kPrimReduceMean->name()); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| node_proto->add_input(input_data); | |||
| if (input_axis->isa<ValueNode>()) { | |||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name("axes"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | |||
| auto axis_value = dyn_cast<ValueNode>(input_axis)->value(); | |||
| @@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons | |||
| } | |||
| } | |||
| void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | |||
| auto input_type = node->input(2); | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type(prim::kPrimCast->name()); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| node_proto->add_input(input_data); | |||
| if (input_type->isa<ValueNode>()) { | |||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name("to"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | |||
| auto type_value = dyn_cast<ValueNode>(input_type)->value(); | |||
| @@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod | |||
| } | |||
| } | |||
| void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | |||
| auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); | |||
| @@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo | |||
| // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] | |||
| if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { | |||
| auto node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type("Unsqueeze"); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | |||
| attr_proto->set_name("axes"); | |||
| attr_proto->add_ints(1); | |||
| @@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo | |||
| auto node_idx = AllocateNodeIndex(); | |||
| (*node_map_ptr)[node] = node_idx; | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->set_op_type("PRelu"); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| node_proto->add_input(input_x); | |||
| node_proto->add_input(input_slope); | |||
| } | |||
| void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert | |||
| if (node->IsApply(prim::kPrimReshape)) { | |||
| return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); | |||
| @@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); | |||
| } | |||
| size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, | |||
| onnx::GraphProto* const graph_proto) { | |||
| size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs, | |||
| onnx::GraphProto *const graph_proto) { | |||
| auto op_map = OpConvertRegistry::GetOpConvertMap(); | |||
| auto op_iter = op_map.find(prim->name()); | |||
| if (op_iter == op_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; | |||
| } | |||
| const OpNameInfo& op_convert_info = op_iter->second; | |||
| const OpNameInfo &op_convert_info = op_iter->second; | |||
| auto node_idx = AllocateNodeIndex(); | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->add_output(std::to_string(node_idx)); | |||
| node_proto->set_op_type(op_convert_info.onnx_type()); | |||
| // Set inputs | |||
| for (const auto& input : inputs) { | |||
| for (const auto &input : inputs) { | |||
| auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); | |||
| node_proto->add_input(input_name); | |||
| } | |||
| // Set node attribute | |||
| for (const OpAttrInfo& attr : op_convert_info.op_attrs()) { | |||
| const std::string& attr_name = attr.attr_name(); | |||
| for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { | |||
| const std::string &attr_name = attr.attr_name(); | |||
| ValuePtr attr_value = nullptr; | |||
| if (!attr_name.empty()) { | |||
| attr_value = prim->GetAttr(attr_name); | |||
| @@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma | |||
| MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; | |||
| } | |||
| } | |||
| onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); | |||
| onnx_attr_proto->set_name(attr.onnx_attr_name()); | |||
| attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); | |||
| } | |||
| return node_idx; | |||
| } | |||
| void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto conv_node = dyn_cast<CNode>(node->input(1)); | |||
| auto input_x = conv_node->input(1); // conv input x | |||
| auto input_w = conv_node->input(2); // conv weight(filter) | |||
| @@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| auto matmul_node = dyn_cast<CNode>(node->input(1)); | |||
| auto input_x = matmul_node->input(1); // matmul input x | |||
| auto input_y = matmul_node->input(2); // matmul input y | |||
| @@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto) { | |||
| auto batch_norm_node = dyn_cast<CNode>(node->input(1)); | |||
| PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value()); | |||
| @@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN | |||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); | |||
| } | |||
| void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||
| void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| if (node->inputs().size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | |||
| } | |||
| AnfNodePtr arg = node->input(1); | |||
| std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); | |||
| onnx::ValueInfoProto* output_proto = graph_proto->add_output(); | |||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||
| output_proto->set_name(name); | |||
| SetValueInfoType(arg, output_proto, false); | |||
| } | |||
| std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||
| onnx::GraphProto* const graph_proto) { | |||
| std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *const graph_proto) { | |||
| if (node->isa<CNode>()) { | |||
| auto iter = node_map_ptr->find(node); | |||
| if (iter == node_map_ptr->end()) { | |||
| @@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN | |||
| (*node_map_ptr)[node] = node_idx; | |||
| std::string node_name = std::to_string(node_idx); | |||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||
| node_proto->add_output(node_name); | |||
| SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto); | |||
| @@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN | |||
| MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); | |||
| } | |||
| void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { | |||
| void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { | |||
| auto tuple_ptr = dyn_cast<ValueTuple>(value); | |||
| MS_EXCEPTION_IF_NULL(tuple_ptr); | |||
| if (tuple_ptr->size() == 0) { | |||
| @@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto | |||
| } | |||
| } | |||
| void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { | |||
| void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { | |||
| node_proto->set_op_type("Constant"); | |||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_name("value"); | |||
| MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; | |||
| } | |||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { | |||
| OnnxExporter exporter; | |||
| return exporter.GetOnnxProtoString(func_graph); | |||
| } | |||
| @@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown }; | |||
| // Whether has a T type data in AnyPtrList. | |||
| template <class T> | |||
| bool HasType(const AnyPtrList& list) { | |||
| bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is<T>(); }); | |||
| bool HasType(const AnyPtrList &list) { | |||
| bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); }); | |||
| return ret; | |||
| } | |||
| DataType InferType(const AnyPtrList& list) { | |||
| DataType InferType(const AnyPtrList &list) { | |||
| if (HasType<double>(list)) { | |||
| return DataType::kDouble; | |||
| } else if (HasType<float>(list)) { | |||
| @@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) { | |||
| } | |||
| #define SCALAR_OP(op_t) \ | |||
| ValuePtr Scalar##op_t(const ValuePtrList& list) { \ | |||
| ValuePtr Scalar##op_t(const ValuePtrList &list) { \ | |||
| do { \ | |||
| if (list.size() < 2) { \ | |||
| MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | |||
| @@ -223,7 +223,7 @@ SCALAR_OP(Pow) | |||
| SCALAR_OP(Floordiv) | |||
| #define LOGIC_OP(op_t) \ | |||
| ValuePtr Scalar##op_t(const ValuePtrList& list) { \ | |||
| ValuePtr Scalar##op_t(const ValuePtrList &list) { \ | |||
| if (list.size() < 2) { \ | |||
| MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | |||
| } \ | |||
| @@ -274,7 +274,7 @@ LOGIC_OP(Ne) | |||
| LOGIC_OP(Le) | |||
| LOGIC_OP(Ge) | |||
| ValuePtr ScalarUAdd(const ValuePtrList& list) { | |||
| ValuePtr ScalarUAdd(const ValuePtrList &list) { | |||
| if (list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); | |||
| } | |||
| @@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) { | |||
| return x; | |||
| } | |||
| ValuePtr ScalarUSub(const ValuePtrList& list) { | |||
| ValuePtr ScalarUSub(const ValuePtrList &list) { | |||
| if (list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); | |||
| } | |||
| @@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; | |||
| } | |||
| ValuePtr ScalarLog(const ValuePtrList& list) { | |||
| ValuePtr ScalarLog(const ValuePtrList &list) { | |||
| if (list.empty()) { | |||
| MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; | |||
| } | |||
| @@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); | |||
| } | |||
| ValuePtr BoolNot(const ValuePtrList& list) { | |||
| ValuePtr BoolNot(const ValuePtrList &list) { | |||
| if (list.empty()) { | |||
| MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; | |||
| } | |||
| @@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); | |||
| } | |||
| ValuePtr BoolAnd(const ValuePtrList& list) { | |||
| ValuePtr BoolAnd(const ValuePtrList &list) { | |||
| if (list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; | |||
| } | |||
| @@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; | |||
| } | |||
| ValuePtr BoolOr(const ValuePtrList& list) { | |||
| ValuePtr BoolOr(const ValuePtrList &list) { | |||
| if (list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; | |||
| } | |||
| @@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) { | |||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; | |||
| } | |||
| ValuePtr BoolEq(const ValuePtrList& list) { | |||
| ValuePtr BoolEq(const ValuePtrList &list) { | |||
| if (list.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; | |||
| } | |||
| @@ -29,29 +29,29 @@ namespace prim { | |||
| using Any = mindspore::Any; | |||
| using AnyPtrList = std::vector<std::shared_ptr<Any>>; | |||
| using ValuePtrList = std::vector<ValuePtr>; | |||
| using OpsFunction = std::function<Any(const AnyPtrList&)>; | |||
| using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr>&)>; | |||
| using OpsFunction = std::function<Any(const AnyPtrList &)>; | |||
| using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr> &)>; | |||
| ValuePtr ScalarAdd(const ValuePtrList& list); | |||
| ValuePtr ScalarSub(const ValuePtrList& list); | |||
| ValuePtr ScalarMul(const ValuePtrList& list); | |||
| ValuePtr ScalarDiv(const ValuePtrList& list); | |||
| ValuePtr ScalarMod(const ValuePtrList& list); | |||
| ValuePtr ScalarPow(const ValuePtrList& list); | |||
| ValuePtr ScalarFloordiv(const ValuePtrList& list); | |||
| ValuePtr ScalarUAdd(const ValuePtrList& list); | |||
| ValuePtr ScalarUSub(const ValuePtrList& list); | |||
| ValuePtr ScalarLog(const ValuePtrList& list); | |||
| ValuePtr ScalarEq(const ValuePtrList& list); | |||
| ValuePtr ScalarLt(const ValuePtrList& list); | |||
| ValuePtr ScalarGt(const ValuePtrList& list); | |||
| ValuePtr ScalarNe(const ValuePtrList& list); | |||
| ValuePtr ScalarLe(const ValuePtrList& list); | |||
| ValuePtr ScalarGe(const ValuePtrList& list); | |||
| ValuePtr BoolNot(const ValuePtrList& list); | |||
| ValuePtr BoolAnd(const ValuePtrList& list); | |||
| ValuePtr BoolOr(const ValuePtrList& list); | |||
| ValuePtr BoolEq(const ValuePtrList& list); | |||
| ValuePtr ScalarAdd(const ValuePtrList &list); | |||
| ValuePtr ScalarSub(const ValuePtrList &list); | |||
| ValuePtr ScalarMul(const ValuePtrList &list); | |||
| ValuePtr ScalarDiv(const ValuePtrList &list); | |||
| ValuePtr ScalarMod(const ValuePtrList &list); | |||
| ValuePtr ScalarPow(const ValuePtrList &list); | |||
| ValuePtr ScalarFloordiv(const ValuePtrList &list); | |||
| ValuePtr ScalarUAdd(const ValuePtrList &list); | |||
| ValuePtr ScalarUSub(const ValuePtrList &list); | |||
| ValuePtr ScalarLog(const ValuePtrList &list); | |||
| ValuePtr ScalarEq(const ValuePtrList &list); | |||
| ValuePtr ScalarLt(const ValuePtrList &list); | |||
| ValuePtr ScalarGt(const ValuePtrList &list); | |||
| ValuePtr ScalarNe(const ValuePtrList &list); | |||
| ValuePtr ScalarLe(const ValuePtrList &list); | |||
| ValuePtr ScalarGe(const ValuePtrList &list); | |||
| ValuePtr BoolNot(const ValuePtrList &list); | |||
| ValuePtr BoolAnd(const ValuePtrList &list); | |||
| ValuePtr BoolOr(const ValuePtrList &list); | |||
| ValuePtr BoolEq(const ValuePtrList &list); | |||
| std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail"); | |||
| // Apply a function of two arguments cumulatively to the items of a sequence, | |||
| // from left to right, so as to reduce the sequence to a single value.For example, | |||
| // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). | |||
| AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { | |||
| AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { | |||
| std::shared_ptr<Any> ret; | |||
| size_t size = list.size(); | |||
| if (size < 2) { | |||
| @@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { | |||
| return ret; | |||
| } | |||
| AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector<AnfNodePtr>& list) { | |||
| AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) { | |||
| size_t size = list.size(); | |||
| if (size < 2) { | |||
| MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; | |||
| @@ -121,7 +121,7 @@ void HyperMap::Init() { | |||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||
| } | |||
| HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf) | |||
| HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf) | |||
| : MetaFuncGraph("hyper_map"), | |||
| fn_leaf_(fn_leaf), | |||
| broadcast_(false), | |||
| @@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf) | |||
| Init(); | |||
| } | |||
| HyperMap::HyperMap(const HyperMap& h) | |||
| HyperMap::HyperMap(const HyperMap &h) | |||
| : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { | |||
| Init(); | |||
| } | |||
| AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||
| const ArgsPairList& arg_map) { | |||
| AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_map) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> inputs; | |||
| if (fn_arg != nullptr) { | |||
| @@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf | |||
| } | |||
| (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), | |||
| [](const std::pair<AnfNodePtr, Any>& item) { return item.first; }); | |||
| [](const std::pair<AnfNodePtr, Any> &item) { return item.first; }); | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t size = type->elements().size(); | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| auto lhs = std::static_pointer_cast<List>(item.second); | |||
| MS_EXCEPTION_IF_NULL(lhs); | |||
| return lhs->elements().size() != size; | |||
| @@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph | |||
| (void)std::transform( | |||
| arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), | |||
| [&func_graph, i](const std::pair<AnfNodePtr, Any>& item) { | |||
| [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) { | |||
| return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | |||
| }); | |||
| @@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t size = type->elements().size(); | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| auto lhs = std::static_pointer_cast<Tuple>(item.second); | |||
| MS_EXCEPTION_IF_NULL(lhs); | |||
| return lhs->elements().size() != size; | |||
| @@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGrap | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, | |||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGrap | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||
| AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||
| bool found = false; | |||
| TypeId id = kObjectTypeEnd; | |||
| std::pair<AnfNodePtr, TypePtr> pair; | |||
| for (auto& item : arg_map) { | |||
| for (auto &item : arg_map) { | |||
| pair = item; | |||
| id = item.second->type_id(); | |||
| if (nonleaf_.count(id)) { | |||
| @@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||
| if (found) { | |||
| // In a nonleaf situation, all arguments must have the same generic. | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr>& item) { | |||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| if (item.first != pair.first) { | |||
| return item.second->type_id() != pair.second->type_id(); | |||
| } | |||
| @@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||
| oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" | |||
| << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | |||
| int idx = 0; | |||
| for (auto& item : arg_map) { | |||
| for (auto &item : arg_map) { | |||
| oss << ++idx << ": " << item.second->ToString() << "\n"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); | |||
| @@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||
| } | |||
| } | |||
| ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { | |||
| ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { | |||
| TypePtr type_tensor = std::make_shared<TensorType>(); | |||
| bool flag = std::any_of( | |||
| args_spec_list.begin(), args_spec_list.end(), | |||
| [type_tensor](const std::pair<AnfNodePtr, TypePtr>& item) { return IsSubType(item.second, type_tensor); }); | |||
| [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); }); | |||
| if (flag && broadcast_) { | |||
| ArgsPairList ret; | |||
| for (auto& item : args_spec_list) { | |||
| for (auto &item : args_spec_list) { | |||
| if (!IsSubType(item.second, type_tensor)) { | |||
| TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second); | |||
| ret.push_back( | |||
| @@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL | |||
| return args_spec_list; | |||
| } | |||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { | |||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->debug_info()->set_name("hyper_map"); | |||
| @@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { | |||
| return ptrGraph; | |||
| } | |||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { | |||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| if (fn_leaf_ == nullptr) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| // Assert that hypermap's function param does not contain free variables | |||
| @@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& | |||
| AbstractBasePtrList broadened; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), | |||
| [](const AbstractBasePtr& arg) -> AbstractBasePtr { | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| return arg->Broaden(); | |||
| }); | |||
| return broadened; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { | |||
| (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_") | |||
| .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) | |||
| .def(py::init<>()); | |||
| })); | |||
| FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { | |||
| FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { | |||
| MS_EXCEPTION_IF_NULL(a_tuple); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| @@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu | |||
| return ret; | |||
| } | |||
| FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) { | |||
| FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { | |||
| MS_EXCEPTION_IF_NULL(a_list); | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| @@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list | |||
| return ret; | |||
| } | |||
| FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| if (args_spec_list.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; | |||
| } | |||
| @@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| Tail_, ([](const py::module* m) { | |||
| (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string&>()); | |||
| Tail_, ([](const py::module *m) { | |||
| (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>()); | |||
| })); | |||
| FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| int tuple_size = SizeToInt(args_spec_list.size()); | |||
| std::ostringstream ss; | |||
| @@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg | |||
| return fg; | |||
| } | |||
| GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) | |||
| GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) | |||
| : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { | |||
| if (get_by_list) { | |||
| signatures_ = | |||
| @@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_ | |||
| } | |||
| } | |||
| FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, | |||
| const std::vector<AnfNodePtr>& params_list, bool applyJ) { | |||
| FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||
| const std::vector<AnfNodePtr> ¶ms_list, bool applyJ) { | |||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | |||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| @@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, | |||
| return ret; | |||
| } | |||
| void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||
| ValueNodePtr opsTupleItem) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An | |||
| } | |||
| // Generate the graph. | |||
| FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| if (args_spec_list.size() < 1) { | |||
| MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " | |||
| << args_spec_list.size() << "."; | |||
| @@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp | |||
| return dfBuilder; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | |||
| (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>( | |||
| *m, "GradOperation_") | |||
| .def(py::init<std::string&>(), py::arg("fn")) | |||
| .def(py::init<std::string&, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), | |||
| .def(py::init<std::string &>(), py::arg("fn")) | |||
| .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), | |||
| py::arg("get_by_list"), py::arg("sens_param")); | |||
| })); | |||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { | |||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | |||
| fn_cache_.clear(); | |||
| signatures_ = std::vector<Signature>({// def multitype(*args:ref): | |||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| @@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) | |||
| fn_cache_[types] = s_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| @@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& | |||
| fn_cache_py_[types] = py_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, const py::function& py_fn) { | |||
| void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) { | |||
| TypePtrList types; | |||
| for (auto& type_name : types_name) { | |||
| for (auto &type_name : types_name) { | |||
| auto type_ptr = StringToType(type_name); | |||
| if (type_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; | |||
| @@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, co | |||
| Register(types, py_fn); | |||
| } | |||
| void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { | |||
| void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { | |||
| std::vector<std::string> types_name; | |||
| for (size_t it = 0; it < tuple.size(); ++it) { | |||
| py::object name_py = tuple[it]; | |||
| @@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& | |||
| } | |||
| Register(types_name, py_fn); | |||
| } | |||
| static TypePtr UnwrapRef(const TypePtr& type) { | |||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||
| if (type->isa<RefType>()) { | |||
| return type->cast<RefTypePtr>()->subtype(); | |||
| } | |||
| return type; | |||
| } | |||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||
| bool find_fn = false; | |||
| py::function py_fn; | |||
| for (auto& item : fn_cache_py_) { | |||
| for (auto &item : fn_cache_py_) { | |||
| TypePtrList sign = item.first; | |||
| if (sign.size() != types.size()) { | |||
| continue; | |||
| @@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||
| oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | |||
| << "`, corresponding location info:\n"; | |||
| int idx = 0; | |||
| for (auto& item : fn_cache_py_) { | |||
| for (auto &item : fn_cache_py_) { | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; | |||
| @@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||
| << oss.str(); | |||
| } | |||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { | |||
| (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( | |||
| *m, "MultitypeFuncGraph_") | |||
| .def(py::init<std::string&>()) | |||
| .def(py::init<std::string &>()) | |||
| .def("register_fn", &MultitypeFuncGraph::PyRegister); | |||
| })); | |||
| // Generate the ListMap func graph. | |||
| FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| size_t args_num = args_spec_list.size(); | |||
| // args: fn, list1, list2, ... | |||
| if (args_num < 2) { | |||
| @@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis | |||
| return fg_ptr; | |||
| } | |||
| void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgnext_ptr, | |||
| const FuncGraphPtr& fg_ptr) { | |||
| void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr, | |||
| const FuncGraphPtr &fg_ptr) { | |||
| MS_EXCEPTION_IF_NULL(fg_ptr); | |||
| AnfNodePtr fn = fg_ptr->add_parameter(); | |||
| @@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& | |||
| fgtrue_ptr->set_output(output_cnode); | |||
| } | |||
| void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgcond_ptr, | |||
| const FuncGraphPtr& fg_ptr) { | |||
| void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr, | |||
| const FuncGraphPtr &fg_ptr) { | |||
| MS_EXCEPTION_IF_NULL(fg_ptr); | |||
| AnfNodePtr fn = fg_ptr->add_parameter(); | |||
| @@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& | |||
| fg_ptr->set_output(output_cnode); | |||
| } | |||
| FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // args: tuple1, tuple2 | |||
| abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); | |||
| AbstractBasePtr abs_a = args_spec_list[0]; | |||
| @@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li | |||
| return ret; | |||
| } | |||
| int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { | |||
| int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { | |||
| MS_EXCEPTION_IF_NULL(scalar); | |||
| return GetValue<int>(scalar->BuildValue()); | |||
| } | |||
| @@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) { | |||
| return index; | |||
| } | |||
| int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { | |||
| int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { | |||
| MS_EXCEPTION_IF_NULL(member); | |||
| if (member->isa<AbstractScalar>()) { | |||
| @@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std | |||
| << member->ToString(); | |||
| } | |||
| void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, | |||
| int* stop_index, int* step_value) { | |||
| void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, | |||
| int *stop_index, int *step_value) { | |||
| MS_EXCEPTION_IF_NULL(tuple); | |||
| MS_EXCEPTION_IF_NULL(slice); | |||
| MS_EXCEPTION_IF_NULL(start_index); | |||
| @@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl | |||
| } | |||
| } | |||
| FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // slice a tuple | |||
| // args: tuple, start index, end index, step | |||
| const std::string op_name("TupleSlice"); | |||
| @@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||
| return ret; | |||
| } | |||
| int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) { | |||
| int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) { | |||
| unsigned int number_dec = 0; | |||
| for (size_t index = 0; index < number_bin.size(); index++) { | |||
| number_dec |= number_bin[index] << index; | |||
| @@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) { | |||
| return static_cast<int>(number_dec); | |||
| } | |||
| void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vector<int>* end, | |||
| std::vector<int>* strides, int length) { | |||
| void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end, | |||
| std::vector<int> *strides, int length) { | |||
| MS_EXCEPTION_IF_NULL(slice); | |||
| MS_EXCEPTION_IF_NULL(begin); | |||
| MS_EXCEPTION_IF_NULL(end); | |||
| @@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vec | |||
| strides->push_back(step_value); | |||
| } | |||
| int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector<int>& shape, | |||
| std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { | |||
| int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape, | |||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||
| MS_EXCEPTION_IF_NULL(slice_tuple); | |||
| MS_EXCEPTION_IF_NULL(begin); | |||
| MS_EXCEPTION_IF_NULL(end); | |||
| @@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, | |||
| return ConvertBinaryToDecimal(shrink); | |||
| } | |||
| int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector<int>& shape, | |||
| std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { | |||
| int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape, | |||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||
| MS_EXCEPTION_IF_NULL(begin); | |||
| MS_EXCEPTION_IF_NULL(end); | |||
| MS_EXCEPTION_IF_NULL(strides); | |||
| @@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const | |||
| return 0; | |||
| } | |||
| int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector<int>& shape, | |||
| std::vector<int>* begin, std::vector<int>* end, | |||
| std::vector<int>* strides) { | |||
| int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape, | |||
| std::vector<int> *begin, std::vector<int> *end, | |||
| std::vector<int> *strides) { | |||
| MS_EXCEPTION_IF_NULL(begin); | |||
| MS_EXCEPTION_IF_NULL(end); | |||
| MS_EXCEPTION_IF_NULL(strides); | |||
| @@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co | |||
| return 1; | |||
| } | |||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // slice a tensor | |||
| // args: tensor, slice or slice tuple | |||
| const std::string op_name = std::string("TensorSlice"); | |||
| @@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec | |||
| shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | |||
| } else { | |||
| std::ostringstream args_info; | |||
| for (const auto& arg : args_spec_list) { | |||
| for (const auto &arg : args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| args_info << arg->ToString() << "\n"; | |||
| } | |||
| @@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec | |||
| return ret_graph; | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| TupleAdd_, ([](const py::module* m) { | |||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>()); | |||
| })); | |||
| REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | |||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_") | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { | |||
| (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_") | |||
| .def(py::init<std::string&>()); | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { | |||
| (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") | |||
| .def(py::init<std::string&>()); | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -47,20 +47,20 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; | |||
| class MultitypeFuncGraph : public MetaFuncGraph { | |||
| public: | |||
| explicit MultitypeFuncGraph(const std::string& name); | |||
| explicit MultitypeFuncGraph(const std::string &name); | |||
| ~MultitypeFuncGraph() override = default; | |||
| MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) | |||
| using specialize_fn = FuncGraph* (*)(TypePtrList); | |||
| using specialize_fn = FuncGraph *(*)(TypePtrList); | |||
| // Register a method which specialize based on types vectors; | |||
| virtual void Register(const TypePtrList& types, specialize_fn s_fn); | |||
| virtual void Register(const TypePtrList& types, const py::function& py_fn); | |||
| virtual void Register(const std::vector<std::string>& types_name, const py::function& py_fn); | |||
| virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); | |||
| virtual void Register(const TypePtrList &types, specialize_fn s_fn); | |||
| virtual void Register(const TypePtrList &types, const py::function &py_fn); | |||
| virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn); | |||
| virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; | |||
| size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } | |||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual>& GetPyFunctions() const { | |||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const { | |||
| return fn_cache_py_; | |||
| } | |||
| @@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | |||
| class HyperMap : public MetaFuncGraph { | |||
| public: | |||
| explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr); | |||
| HyperMap(const HyperMap& h); | |||
| explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr); | |||
| HyperMap(const HyperMap &h); | |||
| void Init(); | |||
| HyperMap& operator=(const HyperMap& h) { | |||
| HyperMap &operator=(const HyperMap &h) { | |||
| if (this != &h) { | |||
| fn_leaf_ = h.fn_leaf_; | |||
| broadcast_ = h.broadcast_; | |||
| @@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph { | |||
| ~HyperMap() override = default; | |||
| MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) | |||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; | |||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; | |||
| MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } | |||
| private: | |||
| AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||
| const ArgsPairList& arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||
| const ArgsPairList& arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||
| const ArgsPairList& arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||
| const ArgsPairList& arg_map); | |||
| AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); | |||
| ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); | |||
| AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_map); | |||
| AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_map); | |||
| AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); | |||
| ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); | |||
| MultitypeFuncGraphPtr fn_leaf_; | |||
| bool broadcast_; | |||
| @@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr<HyperMap>; | |||
| class HyperMapPy : public HyperMap { | |||
| public: | |||
| explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr) : HyperMap(fn_leaf) {} | |||
| explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {} | |||
| ~HyperMapPy() override = default; | |||
| MS_DECLARE_PARENT(HyperMapPy, HyperMap) | |||
| }; | |||
| @@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap; | |||
| class Tail : public MetaFuncGraph { | |||
| public: | |||
| explicit Tail(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit Tail(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~Tail() override = default; | |||
| MS_DECLARE_PARENT(Tail, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); | |||
| FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); | |||
| FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); | |||
| friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } | |||
| friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using TailPtr = std::shared_ptr<Tail>; | |||
| class MakeTupleGradient : public MetaFuncGraph { | |||
| public: | |||
| explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~MakeTupleGradient() override = default; | |||
| MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; | |||
| class GradOperation : public MetaFuncGraph { | |||
| public: | |||
| explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, | |||
| explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, | |||
| bool sens_param = false); | |||
| ~GradOperation() override = default; | |||
| MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) | |||
| FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams, | |||
| FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams, | |||
| bool applyJ = false); | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| bool sens_param() const { return sens_param_; } | |||
| bool get_all_; | |||
| bool get_by_list_; | |||
| bool sens_param_; | |||
| private: | |||
| void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||
| void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||
| ValueNodePtr opsTupleItem); | |||
| }; | |||
| using GradOperationPtr = std::shared_ptr<GradOperation>; | |||
| class ListMap { | |||
| public: | |||
| explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } | |||
| explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } | |||
| ~ListMap() = default; | |||
| void MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); | |||
| void MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); | |||
| void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); | |||
| void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); | |||
| private: | |||
| std::string name_; | |||
| @@ -181,31 +181,31 @@ class ListMap { | |||
| class TupleAdd : public MetaFuncGraph { | |||
| public: | |||
| explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~TupleAdd() override = default; | |||
| MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using TupleAddPtr = std::shared_ptr<TupleAdd>; | |||
| class TupleSlice : public MetaFuncGraph { | |||
| public: | |||
| explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~TupleSlice() override = default; | |||
| MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using TupleSlicePtr = std::shared_ptr<TupleSlice>; | |||
| class TensorSlice : public MetaFuncGraph { | |||
| public: | |||
| explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~TensorSlice() override = default; | |||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | |||
| @@ -34,7 +34,7 @@ namespace prim { | |||
| namespace { | |||
| using PatternListType = std::initializer_list<BaseRef>; | |||
| const std::vector<Signature>& GetSignature(const ValuePtr& function) { | |||
| const std::vector<Signature> &GetSignature(const ValuePtr &function) { | |||
| static const auto empty = std::vector<Signature>(); | |||
| if (function->isa<Primitive>()) { | |||
| return function->cast<PrimitivePtr>()->signatures(); | |||
| @@ -44,8 +44,8 @@ const std::vector<Signature>& GetSignature(const ValuePtr& function) { | |||
| return empty; | |||
| } | |||
| void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, | |||
| const std::vector<Signature>& signature, bool has_var, std::vector<AnfNodePtr>* op_inputs) { | |||
| void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, | |||
| const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) { | |||
| std::size_t sig_size = signature.size(); | |||
| auto positional_size = sig_size; | |||
| if (has_var) { | |||
| @@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg | |||
| } | |||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | |||
| std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType>& dtypes, | |||
| const abstract::AbstractBasePtrList& args_spec_list) { | |||
| std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes, | |||
| const abstract::AbstractBasePtrList &args_spec_list) { | |||
| // record index for signature.dtypes of the same type | |||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | |||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | |||
| @@ -89,7 +89,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur | |||
| continue; | |||
| } | |||
| for (const auto& index : indexs) { | |||
| for (const auto &index : indexs) { | |||
| AbstractBasePtr arg_value = args_spec_list[index]; | |||
| if (arg_value->isa<abstract::AbstractRef>()) { | |||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | |||
| @@ -104,7 +104,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur | |||
| return dst_type; | |||
| } | |||
| AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const FuncGraphPtr& graph) { | |||
| AnfNodePtr DoCast(const AnfNodePtr ¶m, const AnfNodePtr &source_param, const FuncGraphPtr &graph) { | |||
| // op and module import path | |||
| auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional"); | |||
| MS_EXCEPTION_IF_NULL(prim_dtype); | |||
| @@ -116,11 +116,11 @@ AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const | |||
| return NewCNode({cast_node, param, dtype_node}, graph); | |||
| } | |||
| void DoAutoCast(const std::vector<Signature>& signature, const abstract::AbstractBasePtrList& args_spec_list, | |||
| const FuncGraphPtr& graph, std::vector<AnfNodePtr>* op_inputs) { | |||
| void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list, | |||
| const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) { | |||
| std::vector<SignatureEnumDType> dtypes; | |||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | |||
| [](const Signature& sig) { return sig.dtype; }); | |||
| [](const Signature &sig) { return sig.dtype; }); | |||
| int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | |||
| if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { | |||
| return; | |||
| @@ -143,10 +143,10 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac | |||
| } | |||
| } | |||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||
| const AbstractBasePtrList& args_spec_list, const std::vector<AnfNodePtr>& params_list) { | |||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) { | |||
| // args: original inputs | |||
| auto& signature = GetSignature(function); | |||
| auto &signature = GetSignature(function); | |||
| std::size_t sig_size = signature.size(); | |||
| auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); | |||
| if (sig_size > 0) { | |||
| @@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func | |||
| } | |||
| } // namespace | |||
| AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||
| const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { | |||
| AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { | |||
| auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); | |||
| return new_cnode; | |||
| } | |||
| FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | |||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | |||
| @@ -37,17 +37,17 @@ namespace mindspore { | |||
| namespace prim { | |||
| class DoSignatureMetaFuncGraph : public MetaFuncGraph { | |||
| public: | |||
| explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) | |||
| explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) | |||
| : MetaFuncGraph("S-" + name), function_(function) {} | |||
| ~DoSignatureMetaFuncGraph() override = default; | |||
| MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; | |||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; | |||
| const ValuePtr function() const { return function_; } | |||
| friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { | |||
| friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { | |||
| return &lhs == &rhs; | |||
| } | |||
| @@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { | |||
| }; | |||
| using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; | |||
| AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||
| const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); | |||
| AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||
| const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,7 @@ | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { | |||
| FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { | |||
| abstract::CheckArgsSize("ListAppend", args_list, 2); | |||
| AbstractBasePtr arg0 = args_list[0]; | |||
| @@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& | |||
| return ret; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { | |||
| (void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_") | |||
| .def(py::init<std::string&>()); | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -28,15 +28,15 @@ namespace mindspore { | |||
| namespace prim { | |||
| class ListAppend : public MetaFuncGraph { | |||
| public: | |||
| explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~ListAppend() override = default; | |||
| MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; | |||
| friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { | |||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override; | |||
| friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) { | |||
| os << list_append.name_; | |||
| return os; | |||
| } | |||
| friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } | |||
| friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using ListAppendPtr = std::shared_ptr<ListAppend>; | |||
| } // namespace prim | |||
| @@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg; | |||
| using mindspore::abstract::AbstractTuple; | |||
| using mindspore::abstract::AbstractTuplePtr; | |||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // slice a tensor | |||
| // args: tensor, slice or slice tuple | |||
| const std::string op_name = std::string("UnpackCall"); | |||
| @@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||
| AnfNodePtr para_dict = ret_graph->add_parameter(); | |||
| auto dict_elems = arg_dict->elements(); | |||
| (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), | |||
| [ret_graph, para_dict](const AbstractAttribute& item) { | |||
| [ret_graph, para_dict](const AbstractAttribute &item) { | |||
| auto dict_get_item = ret_graph->NewCNode( | |||
| {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); | |||
| return ret_graph->NewCNode( | |||
| @@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||
| return ret_graph; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { | |||
| (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") | |||
| .def(py::init<std::string&>()); | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| } // namespace prim | |||
| @@ -40,11 +40,11 @@ namespace prim { | |||
| // and generate positional parameters and key-value pairs for function. | |||
| class UnpackCall : public MetaFuncGraph { | |||
| public: | |||
| explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~UnpackCall() override = default; | |||
| MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using UnpackCallPtr = std::shared_ptr<UnpackCall>; | |||
| @@ -36,7 +36,7 @@ namespace prim { | |||
| using mindspore::abstract::AbstractBase; | |||
| using mindspore::abstract::AbstractTuple; | |||
| FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||
| FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| // zip operation: | |||
| // input: tuple arguments | |||
| // output: tuple of items of input iterated on every input | |||
| @@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||
| MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; | |||
| } | |||
| auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { | |||
| auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { | |||
| MS_EXCEPTION_IF_NULL(abs); | |||
| return abs->isa<AbstractTuple>(); | |||
| }); | |||
| @@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||
| } | |||
| auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), | |||
| [](const AbstractBasePtr& x, const AbstractBasePtr& y) { | |||
| [](const AbstractBasePtr &x, const AbstractBasePtr &y) { | |||
| return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | |||
| }); | |||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||
| @@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||
| return ret_graph; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { | |||
| REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { | |||
| (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, | |||
| "ZipOperation_") | |||
| .def(py::init<std::string&>()); | |||
| .def(py::init<std::string &>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr; | |||
| class ZipOperation : public MetaFuncGraph { | |||
| public: | |||
| explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} | |||
| explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} | |||
| ~ZipOperation() override = default; | |||
| MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||
| friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { | |||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||
| friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { | |||
| os << op.name_; | |||
| return os; | |||
| } | |||
| friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; } | |||
| friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } | |||
| }; | |||
| using ZipOperationPtr = std::shared_ptr<ZipOperation>; | |||
| } // namespace prim | |||
| @@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary | |||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | |||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | |||
| ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { | |||
| ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { | |||
| py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); | |||
| ValuePtr node = nullptr; | |||
| bool succ = parse::ConvertData(obj, &node); | |||
| @@ -26,8 +26,8 @@ | |||
| namespace mindspore { | |||
| // namespace to support primitive operators | |||
| namespace prim { | |||
| ValuePtr GetPythonOps(const std::string& op_name, | |||
| const std::string& module_name = "mindspore._extends.parse.standard_method"); | |||
| ValuePtr GetPythonOps(const std::string &op_name, | |||
| const std::string &module_name = "mindspore._extends.parse.standard_method"); | |||
| // Arithmetic | |||
| extern const PrimitivePtr kPrimScalarAdd; | |||
| @@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset; | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) | |||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||
| : Primitive("S-Prim-" + name), function_(function) {} | |||
| ~DoSignaturePrimitive() override = default; | |||
| @@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||
| class UnpackGraphPrimitive : public Primitive { | |||
| public: | |||
| explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) | |||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | |||
| : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} | |||
| ~UnpackGraphPrimitive() override = default; | |||
| MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) | |||
| @@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction() | |||
| {"scalar_sub", kPrimTypeTwoArgs}, | |||
| {"scalar_floordiv", kPrimTypeTwoArgs}}) {} | |||
| bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { | |||
| bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { | |||
| bool result = false; | |||
| if (func != nullptr) { | |||
| @@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu | |||
| return result; | |||
| } | |||
| int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { | |||
| int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| int prim_type = static_cast<int>(kPrimTypeUnknown); | |||
| @@ -41,21 +41,21 @@ class PrimToFunction; | |||
| class PrimToFunction { | |||
| public: | |||
| // Return a thread-safe singleton instance | |||
| static PrimToFunction& GetInstance() { | |||
| static PrimToFunction &GetInstance() { | |||
| static PrimToFunction instance; | |||
| return instance; | |||
| } | |||
| PrimToFunction(const PrimToFunction&) = delete; | |||
| PrimToFunction& operator=(const PrimToFunction&) = delete; | |||
| PrimToFunction(const PrimToFunction &) = delete; | |||
| PrimToFunction &operator=(const PrimToFunction &) = delete; | |||
| ~PrimToFunction() = default; | |||
| // Get the args and return value for a primitive instance. | |||
| bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; | |||
| bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const; | |||
| private: | |||
| PrimToFunction(); | |||
| // Get the number of primitive arguments | |||
| int GetPrimType(const PrimitivePtr& prim) const; | |||
| int GetPrimType(const PrimitivePtr &prim) const; | |||
| const std::unordered_map<std::string, int> prim_func_type_map_; | |||
| }; | |||
| } // namespace prim | |||
| @@ -24,7 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace ad { | |||
| Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) | |||
| Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) | |||
| : primal_(primal), caller_(caller), dout_(nullptr) { | |||
| if (k != nullptr) { | |||
| k_ = k; | |||
| @@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP | |||
| AnfNodePtr Adjoint::k() { return k_; } | |||
| void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } | |||
| void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } | |||
| void Adjoint::UpdateK(const AnfNodePtr& new_k) { | |||
| void Adjoint::UpdateK(const AnfNodePtr &new_k) { | |||
| MS_EXCEPTION_IF_NULL(new_k); | |||
| MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); | |||
| // In recursive case, it needs update. | |||
| for (auto& user : k_user_) { | |||
| for (auto &user : k_user_) { | |||
| MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" | |||
| << new_k->ToString(); | |||
| if (user.first->input(user.second) != k_) { | |||
| @@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; } | |||
| AnfNodePtr Adjoint::dout() { return dout_hole_; } | |||
| void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { | |||
| void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { | |||
| dout_user_.emplace_back(std::make_pair(user, index)); | |||
| } | |||
| void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { | |||
| void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { | |||
| if (dout_ != nullptr) { | |||
| MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); | |||
| auto add = prim::GetPythonOps("hyper_add"); | |||
| @@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { | |||
| void Adjoint::CallDoutHole() { | |||
| if (dout_ != nullptr) { | |||
| for (auto& user : dout_user_) { | |||
| for (auto &user : dout_user_) { | |||
| MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " | |||
| << dout_->ToString(); | |||
| if (user.first->input(user.second) != dout_hole_) { | |||
| @@ -28,15 +28,15 @@ namespace mindspore { | |||
| namespace ad { | |||
| class Adjoint { | |||
| public: | |||
| Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); | |||
| Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); | |||
| ~Adjoint() = default; | |||
| AnfNodePtr primal(); | |||
| AnfNodePtr k(); | |||
| void UpdateK(const AnfNodePtr& k); | |||
| void RegisterKUser(const CNodePtr& user, size_t index); | |||
| void UpdateK(const AnfNodePtr &k); | |||
| void RegisterKUser(const CNodePtr &user, size_t index); | |||
| AnfNodePtr dout(); | |||
| void AccumulateDout(const AnfNodePtr& dout_factor); | |||
| void RegisterDoutUser(const CNodePtr& user, size_t index); | |||
| void AccumulateDout(const AnfNodePtr &dout_factor); | |||
| void RegisterDoutUser(const CNodePtr &user, size_t index); | |||
| void CallDoutHole(); | |||
| private: | |||
| @@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList; | |||
| using mindspore::abstract::AbstractScalar; | |||
| using mindspore::abstract::AbstractTuple; | |||
| static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||
| static AbstractBasePtr Reabs(const AbstractBasePtr &t) { | |||
| if (t == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||
| AbstractBasePtrList baselist; | |||
| auto attributes = abs_class->attributes(); | |||
| (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), | |||
| [](const AbstractAttribute& item) { return item.second; }); | |||
| [](const AbstractAttribute &item) { return item.second; }); | |||
| res = std::make_shared<AbstractTuple>(baselist); | |||
| } else if (t->isa<AbstractDictionary>()) { | |||
| auto abs_dict = dyn_cast<AbstractDictionary>(t); | |||
| AbstractBasePtrList baselist; | |||
| auto elements = abs_dict->elements(); | |||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), | |||
| [](const AbstractAttribute& item) { return item.second; }); | |||
| [](const AbstractAttribute &item) { return item.second; }); | |||
| res = std::make_shared<AbstractTuple>(baselist); | |||
| } else if (t->isa<AbstractList>()) { | |||
| auto abs_dict = dyn_cast<AbstractList>(t); | |||
| @@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||
| return res; | |||
| } | |||
| AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||
| AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [getattr, data, attribute] | |||
| MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); | |||
| @@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||
| auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | |||
| auto ct = dyn_cast<AbstractClass>(dt); | |||
| const auto& cmap = ct->attributes(); | |||
| const auto &cmap = ct->attributes(); | |||
| int count = 0; | |||
| for (auto& item : cmap) { | |||
| for (auto &item : cmap) { | |||
| if (cons_is_str && item.first == cons_str) { | |||
| break; | |||
| } | |||
| @@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | |||
| } | |||
| AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||
| AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| // Inputs should be [dict_getitem, dict, item] | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); | |||
| AnfNodePtr data = inputs[1]; | |||
| @@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||
| auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | |||
| auto ct = dyn_cast<abstract::AbstractDictionary>(dt); | |||
| const auto& cmap = ct->elements(); | |||
| const auto &cmap = ct->elements(); | |||
| int count = 0; | |||
| for (auto& item : cmap) { | |||
| for (auto &item : cmap) { | |||
| if (cons_is_str && item.first == cons_str) { | |||
| break; | |||
| } | |||
| @@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | |||
| } | |||
| AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { | |||
| AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| @@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr ErasePartialNode(const CNodePtr& node) { | |||
| AnfNodePtr ErasePartialNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; | |||
| MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); | |||
| @@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) { | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { | |||
| AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| @@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { | |||
| AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [list_getitem, list, item] | |||
| if (inputs.size() < 3) { | |||
| MS_LOG(EXCEPTION) << "Node's input number < 3."; | |||
| @@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); | |||
| } | |||
| AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { | |||
| AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [list_setitem, list, index, item] | |||
| if (inputs.size() < 4) { | |||
| MS_LOG(EXCEPTION) << "Node's input number < 4."; | |||
| @@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { | |||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); | |||
| } | |||
| AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { | |||
| AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); | |||
| return inputs[2]; | |||
| } | |||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { | |||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [make_keyword_arg, key, value] | |||
| MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); | |||
| return inputs[2]; | |||
| } | |||
| AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { | |||
| AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto& inputs = node->inputs(); | |||
| const auto &inputs = node->inputs(); | |||
| // Inputs should be [extract_keyword_arg, arg, key] | |||
| MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); | |||
| return inputs[2]; | |||
| } | |||
| ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { | |||
| ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { | |||
| const int DEPTH_MAX = 5; | |||
| if (depth > DEPTH_MAX) { | |||
| MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; | |||
| } | |||
| std::vector<ValuePtr> elements; | |||
| for (const auto& it : value_list->value()) { | |||
| for (const auto &it : value_list->value()) { | |||
| ValuePtr value = nullptr; | |||
| if (it->isa<ValueList>()) { | |||
| value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1); | |||
| @@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d | |||
| return std::make_shared<ValueTuple>(elements); | |||
| } | |||
| AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { | |||
| AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| ValuePtr value = node->value(); | |||
| auto value_list = value->cast<ValueListPtr>(); | |||
| @@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { | |||
| // Convert class to Tuple | |||
| // Convert getattr to getitem | |||
| // Convert make_record to make_tuple | |||
| void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||
| void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | |||
| AnfNodeSet all_node = manager->all_nodes(); | |||
| for (auto& node : all_node) { | |||
| for (auto &node : all_node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| AnfNodePtr new_node = nullptr; | |||
| @@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& | |||
| } | |||
| } | |||
| for (auto& node : manager->all_nodes()) { | |||
| for (auto &node : manager->all_nodes()) { | |||
| auto ret = Reabs(node->abstract()); | |||
| node->set_abstract(ret); | |||
| } | |||
| } | |||
| // expand tuples in graph parameters | |||
| static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, | |||
| const std::vector<AnfNodePtr>& params) { | |||
| static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, | |||
| const std::vector<AnfNodePtr> ¶ms) { | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> new_params; | |||
| for (const auto& param : params) { | |||
| for (const auto ¶m : params) { | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| auto param_abs = param->abstract(); | |||
| MS_EXCEPTION_IF_NULL(param_abs); | |||
| @@ -350,7 +350,7 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con | |||
| std::vector<AnfNodePtr> new_param; | |||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)}; | |||
| auto abs_tuple = dyn_cast<AbstractTuple>(param_abs); | |||
| for (auto& elem : abs_tuple->elements()) { | |||
| for (auto &elem : abs_tuple->elements()) { | |||
| auto np = std::make_shared<Parameter>(func_graph); | |||
| np->set_abstract(elem); | |||
| new_param.emplace_back(np); | |||
| @@ -366,11 +366,11 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con | |||
| } | |||
| // expand tuples in graph applies | |||
| static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const std::vector<AnfNodePtr>& inputs) { | |||
| static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| for (const auto& input : inputs) { | |||
| for (const auto &input : inputs) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| auto input_abs = input->abstract(); | |||
| @@ -391,7 +391,7 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st | |||
| int idx = 0; | |||
| std::vector<AnfNodePtr> new_input; | |||
| auto abs_tuple = dyn_cast<AbstractTuple>(input_abs); | |||
| for (auto& elem : abs_tuple->elements()) { | |||
| for (auto &elem : abs_tuple->elements()) { | |||
| auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); | |||
| AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx)); | |||
| c_node->input(2)->set_abstract(aptr); | |||
| @@ -416,19 +416,19 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st | |||
| // tuples in Graph's parameters: AbstractTuple (a, b, c) --> | |||
| // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) | |||
| // cppcheck-suppress unusedFunction | |||
| void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||
| void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(root); | |||
| // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | |||
| AnfNodeSet all_node = manager->all_nodes(); | |||
| for (auto& node : all_node) { | |||
| for (auto &node : all_node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| continue; | |||
| } | |||
| const auto& inputs = cnode->inputs(); | |||
| const auto &inputs = cnode->inputs(); | |||
| // Bypass the first input in inputs as it's fn. | |||
| if (!IsValueNode<Primitive>(inputs[0])) { | |||
| @@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||
| } | |||
| FuncGraphSet all_graph = manager->func_graphs(); | |||
| for (auto& func_graph : all_graph) { | |||
| for (auto &func_graph : all_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); | |||
| manager->SetParameters(func_graph, expand_p); | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| // Automatically adding control depend based on effect order and side effect analysis. | |||
| void AddControlDepend(const FuncGraphPtr& graph); | |||
| void AddControlDepend(const FuncGraphPtr &graph); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ | |||
| @@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func | |||
| nodes.push_back(func_node); | |||
| // {unpackcall, {GradOperation, ...}, args...} | |||
| std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), | |||
| [](const AnfNodePtr& node) { return node; }); | |||
| [](const AnfNodePtr &node) { return node; }); | |||
| unpack_graph_node = func_graph->NewCNode(nodes); | |||
| } else { | |||
| auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false); | |||
| @@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func | |||
| nodes.push_back(func_node); | |||
| // {{GradOperation, ...}, args...} | |||
| std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), | |||
| [](const AnfNodePtr& node) { return node; }); | |||
| [](const AnfNodePtr &node) { return node; }); | |||
| unpack_graph_node = func_graph->NewCNode(nodes); | |||
| } | |||
| return unpack_graph_node; | |||
| } | |||
| // get metagraph of value node | |||
| MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { | |||
| MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { | |||
| ValuePtr value; | |||
| if (IsValueNode<prim::DoSignaturePrimitive>(node)) { | |||
| value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function(); | |||
| @@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { | |||
| } | |||
| // check if node is a specific metafuncgraph op | |||
| bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { | |||
| bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { | |||
| if (node != nullptr) { | |||
| auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); | |||
| if (meta_func_graph_ptr == nullptr) { | |||
| @@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr | |||
| // {{GradOperation, g, w}, Ys} | |||
| // {UnPackCall, {GradOperation, g, w}, Ys} | |||
| AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { | |||
| AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -31,20 +31,20 @@ | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, | |||
| const RenormAction& renorm_action) { | |||
| auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &renorm_action) { | |||
| auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||
| const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) { | |||
| auto fn = [prims](const AnfNodePtr& node) -> bool { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | |||
| auto fn = [prims](const AnfNodePtr &node) -> bool { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| for (auto& prim : prims) { | |||
| for (auto &prim : prims) { | |||
| if (IsPrimitiveCNode(node, prim)) { | |||
| return true; | |||
| } | |||
| @@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||
| const PredicateFuncType& predicate, const RenormAction& renorm_action) { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &renorm_action) { | |||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | |||
| } | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| @@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode | |||
| return result; | |||
| } | |||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, | |||
| const SubstitutionPtr& transform) const { | |||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||
| const SubstitutionPtr &transform) const { | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| std::unordered_set<AnfNodePtr> seen_node; | |||
| std::deque<AnfNodePtr> todo{root_node}; | |||
| @@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto& inputs = node->cast<CNodePtr>()->inputs(); | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||
| } | |||
| auto& node_users = manager->node_users(); | |||
| auto &node_users = manager->node_users(); | |||
| if (change && node_users.find(node) != node_users.end()) { | |||
| for (auto& use : node_users[node]) { | |||
| for (auto &use : node_users[node]) { | |||
| auto use_node = use.first; | |||
| todo.push_back(use_node); | |||
| if (seen_node.find(use_node) != seen_node.end()) { | |||
| @@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo | |||
| return changes; | |||
| } | |||
| bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { | |||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| @@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize | |||
| do { | |||
| loop = false; | |||
| for (auto const& transform : list_) { | |||
| for (auto const &transform : list_) { | |||
| auto change = ApplyTransform(optimizer, func_graph->output(), transform); | |||
| changes = changes || change; | |||
| loop = loop || change; | |||
| @@ -28,7 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { | |||
| std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { | |||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | |||
| MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " | |||
| << MAX_RECURSIVE_CALL_TIMES; | |||
| @@ -39,7 +39,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto node_set = manager->node_users()[para]; | |||
| std::unordered_set<CNodePtr> cnode_set; | |||
| for (auto& node_pair : node_set) { | |||
| for (auto &node_pair : node_set) { | |||
| auto cnode = node_pair.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| @@ -54,7 +54,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||
| (void)cnode_set.emplace(cnode); | |||
| } else { | |||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | |||
| for (auto& cnode_sub : cnode_set_sub) { | |||
| for (auto &cnode_sub : cnode_set_sub) { | |||
| (void)cnode_set.emplace(cnode_sub); | |||
| } | |||
| } | |||
| @@ -63,8 +63,8 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||
| } | |||
| Status AllreduceFusion::AddNodeToGraph() { | |||
| const auto& parameters = root_graph_->parameters(); | |||
| for (auto& parameter : parameters) { | |||
| const auto ¶meters = root_graph_->parameters(); | |||
| for (auto ¶meter : parameters) { | |||
| if (!ParameterRequireGrad(parameter)) { | |||
| continue; | |||
| } | |||
| @@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() { | |||
| if (cnode_set.empty()) { | |||
| continue; | |||
| } | |||
| for (auto& cnode : cnode_set) { | |||
| for (auto &cnode : cnode_set) { | |||
| MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); | |||
| if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { | |||
| MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); | |||
| @@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() { | |||
| return SUCCESS; | |||
| } | |||
| CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { | |||
| CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { | |||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | |||
| MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " | |||
| << MAX_RECURSIVE_CALL_TIMES; | |||
| @@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi | |||
| return cnode_dist; | |||
| } else { | |||
| auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); | |||
| for (auto& ele : cnode_dist_next) { | |||
| for (auto &ele : cnode_dist_next) { | |||
| cnode_dist[ele.first] = cost + ele.second; | |||
| } | |||
| } | |||
| } else { | |||
| auto cnode_dist_next = FindNextCNodes(cnode); | |||
| for (auto& ele : cnode_dist_next) { | |||
| for (auto &ele : cnode_dist_next) { | |||
| cnode_dist[ele.first] = ele.second; | |||
| } | |||
| } | |||
| return cnode_dist; | |||
| } | |||
| CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { | |||
| CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { | |||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | |||
| MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " | |||
| << MAX_RECURSIVE_CALL_TIMES; | |||
| } | |||
| const auto& from_inputs = from->inputs(); | |||
| const auto &from_inputs = from->inputs(); | |||
| std::unordered_map<CNodePtr, double> dist_map; | |||
| MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; | |||
| for (auto& input_node : from_inputs) { | |||
| for (auto &input_node : from_inputs) { | |||
| auto cnode_dist = FindCNode(input_node, recursive_times + 1); | |||
| for (auto& ele : cnode_dist) { | |||
| for (auto &ele : cnode_dist) { | |||
| (void)dist_map.emplace(ele); | |||
| } | |||
| } | |||
| @@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu | |||
| Status AllreduceFusion::AddEdgeToGraph() { | |||
| std::unordered_map<CNodePtr, int32_t> cnode_state_map; | |||
| const auto& cnodes = allreduce_graph_.cnode_set(); | |||
| for (auto& cnode : cnodes) { | |||
| const auto &cnodes = allreduce_graph_.cnode_set(); | |||
| for (auto &cnode : cnodes) { | |||
| cnode_state_map[cnode] = 0; | |||
| } | |||
| const auto& head_cnode = allreduce_graph_.head_cnode(); | |||
| const auto &head_cnode = allreduce_graph_.head_cnode(); | |||
| std::queue<CNodePtr> cnode_queue; | |||
| cnode_queue.emplace(head_cnode); | |||
| cnode_state_map[head_cnode] = 1; | |||
| @@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() { | |||
| cnode_queue.pop(); | |||
| cnode_state_map[cur_cnode] = 2; | |||
| auto next = FindNextCNodes(cur_cnode); | |||
| for (auto& ele : next) { | |||
| auto& cnode = ele.first; | |||
| auto& dist = ele.second; | |||
| for (auto &ele : next) { | |||
| auto &cnode = ele.first; | |||
| auto &dist = ele.second; | |||
| if (cnode_state_map[cnode] == 0) { | |||
| cnode_queue.emplace(cnode); | |||
| cnode_state_map[cnode] = 1; | |||
| @@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() { | |||
| return SUCCESS; | |||
| } | |||
| std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { | |||
| std::vector<CNodePtr> FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { | |||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | |||
| MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " | |||
| << MAX_RECURSIVE_CALL_TIMES; | |||
| @@ -184,7 +184,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| AnfNodeIndexSet node_set = manager->node_users()[para]; | |||
| std::vector<CNodePtr> cnode_list; | |||
| for (auto& node_pair : node_set) { | |||
| for (auto &node_pair : node_set) { | |||
| auto cnode = node_pair.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| @@ -210,7 +210,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time | |||
| return cnode_list; | |||
| } | |||
| void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { | |||
| void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { | |||
| MS_EXCEPTION_IF_NULL(mirror_cnode); | |||
| MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; | |||
| auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0)); | |||
| @@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st | |||
| (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name))); | |||
| } | |||
| Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { | |||
| Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { | |||
| auto mirror_cnodes = FindMirror(para); | |||
| if (mirror_cnodes.empty()) { | |||
| MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; | |||
| return SUCCESS; | |||
| } | |||
| if (mirror_cnodes.size() > 2) { | |||
| for (auto& mirror_cnode : mirror_cnodes) { | |||
| for (auto &mirror_cnode : mirror_cnodes) { | |||
| MS_EXCEPTION_IF_NULL(mirror_cnode); | |||
| MS_LOG(INFO) << mirror_cnode->DebugString(); | |||
| } | |||
| @@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { | |||
| << "Mirror CNode found."; | |||
| return FAILED; | |||
| } | |||
| for (auto& mirror_cnode : mirror_cnodes) { | |||
| for (auto &mirror_cnode : mirror_cnodes) { | |||
| auto parameter_name = ParameterName(para); | |||
| SetMirrorFusion(mirror_cnode, fusion, parameter_name); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusion) { | |||
| for (auto& param_node : paras) { | |||
| Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> ¶s, int32_t fusion) { | |||
| for (auto ¶m_node : paras) { | |||
| if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { | |||
| MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; | |||
| return FAILED; | |||
| @@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusi | |||
| return SUCCESS; | |||
| } | |||
| Status AllreduceFusion::SetFusion(const std::vector<double>& cost_map) { | |||
| Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) { | |||
| if (cost_map.size() < 2) { | |||
| MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); | |||
| return FAILED; | |||
| @@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { | |||
| return SetFusionByBackwardCompAndAllreduceTime(); | |||
| } | |||
| Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { | |||
| Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { | |||
| if (ret == nullptr) { | |||
| MS_LOG(ERROR) << "ret is nullptr."; | |||
| return FAILED; | |||
| @@ -50,15 +50,15 @@ class AllreduceFusion { | |||
| allreduce_bandwidth_(0), | |||
| computation_time_parameter_(0) {} | |||
| virtual ~AllreduceFusion() = default; | |||
| Status ProcessAllreduceFusion(const CNodePtr& ret); | |||
| Status ProcessAllreduceFusion(const CNodePtr &ret); | |||
| private: | |||
| Status AddNodeToGraph(); | |||
| CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; | |||
| CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; | |||
| CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; | |||
| CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; | |||
| Status AddEdgeToGraph(); | |||
| std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const; | |||
| Status SetFusion(const std::vector<double>& cost_map); | |||
| Status SetFusion(const std::vector<double> &cost_map); | |||
| Status SetFusionByAlgorithm(int32_t algorithm); | |||
| Status SetFusionByBackwardCompTime(); | |||
| Status SetFusionByBackwardCompAndAllreduceTime(); | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { | |||
| Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { | |||
| AllreduceNodePtr arnode; | |||
| auto cnode_emplace_return = cnode_set_.emplace(node); | |||
| if (!cnode_emplace_return.second) { | |||
| @@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { | |||
| return SUCCESS; | |||
| } | |||
| Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { | |||
| Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { | |||
| auto from_arnode_iter = cnode_arnode_map_.find(from); | |||
| if (from_arnode_iter == cnode_arnode_map_.end()) { | |||
| MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; | |||
| @@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double | |||
| return SUCCESS; | |||
| } | |||
| bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { | |||
| bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { | |||
| auto cnode_iter = cnode_set_.find(node); | |||
| return !(cnode_iter == cnode_set_.end()); | |||
| } | |||
| std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) { | |||
| std::vector<AnfNodePtr> nodes; | |||
| for (auto& cnode_arnode : cnode_arnode_map_) { | |||
| for (auto &cnode_arnode : cnode_arnode_map_) { | |||
| MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() | |||
| << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() | |||
| << " curr_para_size: " << cnode_arnode.second->curr_para_size(); | |||
| @@ -117,7 +117,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou | |||
| std::vector<AnfNodePtr> nodes; | |||
| double cur_para_size = 0; | |||
| double from = to; | |||
| for (auto& arnode : arnode_vec_) { | |||
| for (auto &arnode : arnode_vec_) { | |||
| if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { | |||
| continue; | |||
| } | |||
| @@ -135,14 +135,14 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou | |||
| void AllreduceGraph::PrintCNodeSet() const { | |||
| MS_LOG(INFO) << "CNodeSet:"; | |||
| for (auto& cnode : cnode_set_) { | |||
| for (auto &cnode : cnode_set_) { | |||
| MS_LOG(INFO) << cnode->DebugString(); | |||
| } | |||
| } | |||
| void AllreduceGraph::PrintAllredueGraphInfo() const { | |||
| MS_LOG(INFO) << "max: " << max_; | |||
| for (auto& cnode_arnode : cnode_arnode_map_) { | |||
| for (auto &cnode_arnode : cnode_arnode_map_) { | |||
| MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); | |||
| MS_LOG(INFO) << "arnode info: "; | |||
| cnode_arnode.second->ToString(); | |||
| @@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const { | |||
| void AllreduceGraph::PrintArnodeVec() const { | |||
| MS_LOG(INFO) << "ArnodeVec:"; | |||
| for (auto& arnode : arnode_vec_) { | |||
| for (auto &arnode : arnode_vec_) { | |||
| arnode.ToString(); | |||
| } | |||
| } | |||
| void AllreduceGraph::PrintArnodeSet() const { | |||
| MS_LOG(INFO) << "ArnodeSet:"; | |||
| for (auto& arnode : arnode_set_) { | |||
| for (auto &arnode : arnode_set_) { | |||
| arnode->ToString(); | |||
| } | |||
| } | |||
| void AllreduceGraph::SortArnode() { | |||
| arnode_vec_.clear(); | |||
| for (auto& node : arnode_set_) { | |||
| for (auto &node : arnode_set_) { | |||
| arnode_vec_.emplace_back(*node); | |||
| } | |||
| std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); | |||
| @@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() { | |||
| Status AllreduceGraph::RemoveExtraParas() { | |||
| std::unordered_set<AnfNodePtr> para_map; | |||
| for (auto& node : arnode_vec_) { | |||
| for (auto& para : node.paras()) { | |||
| for (auto &node : arnode_vec_) { | |||
| for (auto ¶ : node.paras()) { | |||
| auto emplac_result = para_map.emplace(para); | |||
| if (!emplac_result.second) { | |||
| MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; | |||
| @@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() { | |||
| return SUCCESS; | |||
| } | |||
| Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { | |||
| Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { | |||
| auto arnode = std::make_shared<AllreduceNode>(AllreduceNode()); | |||
| if (arnode->Init(node) != SUCCESS) { | |||
| MS_LOG(ERROR) << "AllreduceNode Init failed"; | |||
| @@ -42,9 +42,9 @@ class AllreduceGraph { | |||
| cnode_arnode_map_(), | |||
| max_(0) {} | |||
| virtual ~AllreduceGraph() = default; | |||
| Status AddNode(const CNodePtr& node, const AnfNodePtr& para); | |||
| Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); | |||
| bool NodeInGraph(const CNodePtr& node) const; | |||
| Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); | |||
| Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); | |||
| bool NodeInGraph(const CNodePtr &node) const; | |||
| std::vector<AnfNodePtr> GetParaByCost(double from, double to); | |||
| // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is | |||
| // over para_size. | |||
| @@ -60,9 +60,9 @@ class AllreduceGraph { | |||
| void PrintAllredueGraphInfo() const; | |||
| void PrintArnodeVec() const; | |||
| void PrintArnodeSet() const; | |||
| const std::unordered_set<CNodePtr>& cnode_set() const { return cnode_set_; } | |||
| const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; } | |||
| CNodePtr head_cnode() const { return head_cnode_; } | |||
| Status set_head_cnode(const CNodePtr& node); | |||
| Status set_head_cnode(const CNodePtr &node); | |||
| double max() const { return max_; } | |||
| private: | |||