| @@ -94,7 +94,7 @@ PenaltyBreakString: 1000 | |||||
| PenaltyBreakTemplateDeclaration: 10 | PenaltyBreakTemplateDeclaration: 10 | ||||
| PenaltyExcessCharacter: 1000000 | PenaltyExcessCharacter: 1000000 | ||||
| PenaltyReturnTypeOnItsOwnLine: 200 | PenaltyReturnTypeOnItsOwnLine: 200 | ||||
| PointerAlignment: Left | |||||
| PointerAlignment: Right | |||||
| RawStringFormats: | RawStringFormats: | ||||
| - Language: Cpp | - Language: Cpp | ||||
| Delimiters: | Delimiters: | ||||
| @@ -23,7 +23,7 @@ namespace common { | |||||
| const int CACHED_STR_NUM = 1 << 8; | const int CACHED_STR_NUM = 1 << 8; | ||||
| const int CACHED_STR_MASK = CACHED_STR_NUM - 1; | const int CACHED_STR_MASK = CACHED_STR_NUM - 1; | ||||
| std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); | std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); | ||||
| const char* SafeCStr(const std::string&& str) { | |||||
| const char *SafeCStr(const std::string &&str) { | |||||
| static std::atomic<uint32_t> index{0}; | static std::atomic<uint32_t> index{0}; | ||||
| uint32_t cur_index = index++; | uint32_t cur_index = index++; | ||||
| cur_index = cur_index & CACHED_STR_MASK; | cur_index = cur_index & CACHED_STR_MASK; | ||||
| @@ -21,16 +21,16 @@ | |||||
| #include <string> | #include <string> | ||||
| #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | ||||
| ClassType(const ClassType&) = delete; \ | |||||
| ClassType& operator=(const ClassType&) = delete; | |||||
| ClassType(const ClassType &) = delete; \ | |||||
| ClassType &operator=(const ClassType &) = delete; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace common { | namespace common { | ||||
| inline const char* SafeCStr(const std::string& str) { return str.c_str(); } | |||||
| const char* SafeCStr(const std::string&& str); | |||||
| inline const char *SafeCStr(const std::string &str) { return str.c_str(); } | |||||
| const char *SafeCStr(const std::string &&str); | |||||
| static inline std::string GetEnv(const std::string& envvar) { | |||||
| const char* value = ::getenv(envvar.c_str()); | |||||
| static inline std::string GetEnv(const std::string &envvar) { | |||||
| const char *value = ::getenv(envvar.c_str()); | |||||
| if (value == nullptr) { | if (value == nullptr) { | ||||
| return std::string(); | return std::string(); | ||||
| @@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { | |||||
| ~DecodeOp() = default; | ~DecodeOp() = default; | ||||
| Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override; | |||||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||||
| void Print(std::ostream& out) const override { out << "DecodeOp"; } | |||||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||||
| void Print(std::ostream &out) const override { out << "DecodeOp"; } | |||||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||||
| private: | private: | ||||
| bool is_rgb_format_ = true; | bool is_rgb_format_ = true; | ||||
| @@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int | |||||
| rnd_.seed(seed_); | rnd_.seed(seed_); | ||||
| } | } | ||||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||||
| std::vector<std::shared_ptr<Tensor>>* output) { | |||||
| Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||||
| std::vector<std::shared_ptr<Tensor>> *output) { | |||||
| IO_CHECK_VECTOR(input, output); | IO_CHECK_VECTOR(input, output); | ||||
| if (input.size() != NumInput()) | if (input.size() != NumInput()) | ||||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); | return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); | ||||
| @@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs, | |||||
| std::vector<TensorShape>& outputs) { | |||||
| Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs, | |||||
| std::vector<TensorShape> &outputs) { | |||||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | ||||
| outputs.clear(); | outputs.clear(); | ||||
| TensorShape out = TensorShape{-1, -1}; | TensorShape out = TensorShape{-1, -1}; | ||||
| @@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp | |||||
| if (!outputs.empty()) return Status::OK(); | if (!outputs.empty()) return Status::OK(); | ||||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | ||||
| } | } | ||||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) { | |||||
| Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||||
| RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | ||||
| outputs[0] = inputs[0]; | outputs[0] = inputs[0]; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp { | |||||
| ~DistortBoundingBoxCropOp() override = default; | ~DistortBoundingBoxCropOp() override = default; | ||||
| void Print(std::ostream& out) const override { | |||||
| void Print(std::ostream &out) const override { | |||||
| out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; | out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; | ||||
| } | } | ||||
| Status Compute(const std::vector<std::shared_ptr<Tensor>>& input, | |||||
| std::vector<std::shared_ptr<Tensor>>* output) override; | |||||
| Status Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||||
| std::vector<std::shared_ptr<Tensor>> *output) override; | |||||
| uint32_t NumInput() override { return 5; } | uint32_t NumInput() override { return 5; } | ||||
| Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; | |||||
| Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; | |||||
| Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; | |||||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||||
| private: | private: | ||||
| int32_t max_attempts_; | int32_t max_attempts_; | ||||
| @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ | |||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| } | } | ||||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) { | |||||
| Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); | CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); | ||||
| @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std: | |||||
| (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); | (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); | ||||
| return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); | return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); | ||||
| } | } | ||||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) { | |||||
| Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { | |||||
| RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); | ||||
| outputs.clear(); | outputs.clear(); | ||||
| TensorShape out = TensorShape{target_height_, target_width_}; | TensorShape out = TensorShape{target_height_, target_width_}; | ||||
| @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs | |||||
| if (!outputs.empty()) return Status::OK(); | if (!outputs.empty()) return Status::OK(); | ||||
| return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); | ||||
| } | } | ||||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { | |||||
| Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { | |||||
| double scale, aspect; | double scale, aspect; | ||||
| *crop_width = w_in; | *crop_width = w_in; | ||||
| *crop_height = h_in; | *crop_height = h_in; | ||||
| @@ -22,7 +22,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| constexpr char PARALLEL_STRATEGY[] = "strategy"; | constexpr char PARALLEL_STRATEGY[] = "strategy"; | ||||
| void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); | |||||
| void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; | |||||
| // get MindSpore Intermediate Representation Path | // get MindSpore Intermediate Representation Path | ||||
| std::string GetMsIrPath(void) { | std::string GetMsIrPath(void) { | ||||
| std::string path; | std::string path; | ||||
| const char* path_ptr = getenv("MS_IR_PATH"); | |||||
| const char *path_ptr = getenv("MS_IR_PATH"); | |||||
| if (path_ptr != nullptr) { | if (path_ptr != nullptr) { | ||||
| path = path_ptr; | path = path_ptr; | ||||
| char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
| @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { | |||||
| return path; | return path; | ||||
| } | } | ||||
| std::string dump_obj(const py::object& obj, const std::string& path) { | |||||
| std::string dump_obj(const py::object &obj, const std::string &path) { | |||||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | ||||
| py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); | py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); | ||||
| return py::str(name); | return py::str(name); | ||||
| } | } | ||||
| py::object load_obj(const std::string& path) { | |||||
| py::object load_obj(const std::string &path) { | |||||
| py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); | ||||
| py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); | py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); | ||||
| return obj; | return obj; | ||||
| @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { | |||||
| // ============================================= MindSpore IR Exporter ============================================= | // ============================================= MindSpore IR Exporter ============================================= | ||||
| std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||||
| std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { | |||||
| abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape()); | ||||
| TypePtr type = dyn_cast<Type>(nd->Type()); | TypePtr type = dyn_cast<Type>(nd->Type()); | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { | |||||
| std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { | |||||
| std::string pkl_path = GetMsIrPath(); | std::string pkl_path = GetMsIrPath(); | ||||
| // if not specified env 'MS_IR_PATH', do not create any files | // if not specified env 'MS_IR_PATH', do not create any files | ||||
| if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { | if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { | ||||
| @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca | |||||
| return file_prefix + file_name; | return file_prefix + file_name; | ||||
| } | } | ||||
| int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { | |||||
| int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { | |||||
| if (func_graph == nullptr || param == nullptr) { | if (func_graph == nullptr || param == nullptr) { | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& | |||||
| // try to find index of parameter for SymbolicKeyInstance from all exported graphs | // try to find index of parameter for SymbolicKeyInstance from all exported graphs | ||||
| // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different | // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different | ||||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||||
| int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| int ret = -1; | int ret = -1; | ||||
| for (const auto& item : exported) { | |||||
| for (const auto &item : exported) { | |||||
| auto pram_iter = item.second.find(param); | auto pram_iter = item.second.find(param); | ||||
| if (pram_iter != item.second.end()) { | if (pram_iter != item.second.end()) { | ||||
| return pram_iter->second; | return pram_iter->second; | ||||
| @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { | |||||
| std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| return GetValueText(fg, node->value()); | return GetValueText(fg, node->value()); | ||||
| } | } | ||||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { | |||||
| std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { | |||||
| auto py_funcs = mt_func_graph->GetPyFunctions(); | auto py_funcs = mt_func_graph->GetPyFunctions(); | ||||
| if (py_funcs.empty()) { | if (py_funcs.empty()) { | ||||
| return ""; | return ""; | ||||
| @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||||
| oss << "{"; | oss << "{"; | ||||
| bool is_first = true; | bool is_first = true; | ||||
| for (const auto& py_func : py_funcs) { | |||||
| for (const auto &py_func : py_funcs) { | |||||
| if (is_first) { | if (is_first) { | ||||
| is_first = false; | is_first = false; | ||||
| } else { | } else { | ||||
| @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||||
| * ├── GradOperation | * ├── GradOperation | ||||
| * └── TupleAdd | * └── TupleAdd | ||||
| */ | */ | ||||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { | |||||
| std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { | |||||
| if (meta_func_graph == nullptr) { | if (meta_func_graph == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||||
| std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (prim == nullptr) { | if (prim == nullptr) { | ||||
| return oss.str(); | return oss.str(); | ||||
| @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||||
| if (prim->isa<prim::DoSignaturePrimitive>()) { | if (prim->isa<prim::DoSignaturePrimitive>()) { | ||||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim); | ||||
| auto& func = do_signature->function(); | |||||
| auto &func = do_signature->function(); | |||||
| if (func->isa<Primitive>()) { | if (func->isa<Primitive>()) { | ||||
| auto sig_prim = dyn_cast<Primitive>(func); | auto sig_prim = dyn_cast<Primitive>(func); | ||||
| oss << sig_prim->GetAttrsText(); | oss << sig_prim->GetAttrsText(); | ||||
| @@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||||
| std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (ns == nullptr) { | if (ns == nullptr) { | ||||
| return oss.str(); | return oss.str(); | ||||
| @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, | |||||
| const SymbolicKeyInstancePtr& sym_inst) { | |||||
| std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, | |||||
| const SymbolicKeyInstancePtr &sym_inst) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(sym_inst); | MS_EXCEPTION_IF_NULL(sym_inst); | ||||
| AnfNodePtr sym_node = sym_inst->node(); | AnfNodePtr sym_node = sym_inst->node(); | ||||
| @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||||
| std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| // output ValueList, ValueTuple | // output ValueList, ValueTuple | ||||
| ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value); | ValueSequeuePtr seq = dyn_cast<ValueSequeue>(value); | ||||
| @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||||
| std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>(); | ValueDictionaryPtr dict = value->cast<ValueDictionaryPtr>(); | ||||
| oss << "{"; | oss << "{"; | ||||
| bool first_flag = true; | bool first_flag = true; | ||||
| for (const auto& elem : dict->value()) { | |||||
| for (const auto &elem : dict->value()) { | |||||
| if (first_flag) { | if (first_flag) { | ||||
| first_flag = false; | first_flag = false; | ||||
| } else { | } else { | ||||
| @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { | |||||
| std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (check_integrity_) { | if (check_integrity_) { | ||||
| @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { | |||||
| std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| bool is_null_ptr = (func_graph == nullptr || value == nullptr); | bool is_null_ptr = (func_graph == nullptr || value == nullptr); | ||||
| if (is_null_ptr) { | if (is_null_ptr) { | ||||
| @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu | |||||
| } | } | ||||
| // this function is used to output node in CNode's inputs | // this function is used to output node in CNode's inputs | ||||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||||
| const std::map<AnfNodePtr, int>& apply_map) { | |||||
| std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::map<AnfNodePtr, int> &apply_map) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (func_graph == nullptr || node == nullptr) { | if (func_graph == nullptr || node == nullptr) { | ||||
| return oss.str(); | return oss.str(); | ||||
| @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map) { | |||||
| void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map) { | |||||
| bool first_flag = true; | bool first_flag = true; | ||||
| for (const AnfNodePtr& param : parameters) { | |||||
| for (const AnfNodePtr ¶m : parameters) { | |||||
| if (first_flag) { | if (first_flag) { | ||||
| first_flag = false; | first_flag = false; | ||||
| ofs << " "; | ofs << " "; | ||||
| @@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector<AnfNode | |||||
| } | } | ||||
| } | } | ||||
| void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& node) { | |||||
| void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| // output type of each input argument | // output type of each input argument | ||||
| auto& inputs = node->inputs(); | |||||
| auto &inputs = node->inputs(); | |||||
| if (inputs.size() > 1) { | if (inputs.size() > 1) { | ||||
| ofs << " #("; | ofs << " #("; | ||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod | |||||
| ofs << " #scope: " << node->scope()->name(); | ofs << " #scope: " << node->scope()->name(); | ||||
| } | } | ||||
| void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, | |||||
| const FuncGraphPtr& func_graph) { | |||||
| void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, | |||||
| const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| int idx = 1; | int idx = 1; | ||||
| std::map<AnfNodePtr, int> apply_map; | std::map<AnfNodePtr, int> apply_map; | ||||
| for (const AnfNodePtr& node : nodes) { | |||||
| for (const AnfNodePtr &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto& inputs = cnode->inputs(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); | ||||
| // non-return node | // non-return node | ||||
| if (node != func_graph->get_return()) { | if (node != func_graph->get_return()) { | ||||
| @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr> | |||||
| } | } | ||||
| } | } | ||||
| void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { | |||||
| void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun | |||||
| ofs << "}\n"; | ofs << "}\n"; | ||||
| } | } | ||||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt | |||||
| ofs.close(); | ofs.close(); | ||||
| } | } | ||||
| void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||||
| void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||||
| if (graphs.empty()) { | if (graphs.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||||
| param_index = 1; | param_index = 1; | ||||
| for (const auto& tagged_graph : graphs) { | |||||
| for (const auto &tagged_graph : graphs) { | |||||
| tagged_cnodes_ = tagged_graph.second; | tagged_cnodes_ = tagged_graph.second; | ||||
| ExportOneFuncGraph(ofs, tagged_graph.first); | ExportOneFuncGraph(ofs, tagged_graph.first); | ||||
| tagged_cnodes_.clear(); | tagged_cnodes_.clear(); | ||||
| @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector | |||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | #ifdef ENABLE_DUMP_IR | ||||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { | |||||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap | |||||
| ChangeFileMode(filename, S_IRUSR); | ChangeFileMode(filename, S_IRUSR); | ||||
| } | } | ||||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||||
| AnfExporter exporter("", false); | AnfExporter exporter("", false); | ||||
| ChangeFileMode(filename, S_IRWXU); | ChangeFileMode(filename, S_IRWXU); | ||||
| exporter.ExportFuncGraph(filename, graphs); | exporter.ExportFuncGraph(filename, graphs); | ||||
| @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graph | |||||
| ChangeFileMode(filename, S_IRUSR); | ChangeFileMode(filename, S_IRUSR); | ||||
| } | } | ||||
| #else | #else | ||||
| void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||||
| void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { | |||||
| static bool already_printed = false; | static bool already_printed = false; | ||||
| if (already_printed) { | if (already_printed) { | ||||
| return; | return; | ||||
| @@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { | |||||
| << "please recompile source to enable it. See help of building script."; | << "please recompile source to enable it. See help of building script."; | ||||
| } | } | ||||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs) { | |||||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs) { | |||||
| static bool already_printed = false; | static bool already_printed = false; | ||||
| if (already_printed) { | if (already_printed) { | ||||
| return; | return; | ||||
| @@ -732,7 +732,7 @@ enum Token : int { | |||||
| TOK_ERROR // file read error | TOK_ERROR // file read error | ||||
| }; | }; | ||||
| std::map<Token, const char*> token_text = { | |||||
| std::map<Token, const char *> token_text = { | |||||
| {TOK_INVALID, "invalid"}, // invalid token | {TOK_INVALID, "invalid"}, // invalid token | ||||
| {TOK_LPARENTHESIS, "("}, // ( left parenthesis | {TOK_LPARENTHESIS, "("}, // ( left parenthesis | ||||
| {TOK_RPARENTHESIS, ")"}, // ) right parenthesis | {TOK_RPARENTHESIS, ")"}, // ) right parenthesis | ||||
| @@ -761,14 +761,14 @@ std::map<Token, const char*> token_text = { | |||||
| class Lexer { | class Lexer { | ||||
| public: | public: | ||||
| // filename is checked in ImportIR; | // filename is checked in ImportIR; | ||||
| explicit Lexer(const char* filename) : fin(filename) {} | |||||
| explicit Lexer(const char *filename) : fin(filename) {} | |||||
| ~Lexer() { | ~Lexer() { | ||||
| try { | try { | ||||
| if (fin.is_open()) { | if (fin.is_open()) { | ||||
| fin.close(); | fin.close(); | ||||
| } | } | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(ERROR) << "Exception when closing file"; | MS_LOG(ERROR) << "Exception when closing file"; | ||||
| } catch (...) { | } catch (...) { | ||||
| std::string exName(abi::__cxa_current_exception_type()->name()); | std::string exName(abi::__cxa_current_exception_type()->name()); | ||||
| @@ -776,7 +776,7 @@ class Lexer { | |||||
| } | } | ||||
| } | } | ||||
| bool IsSingleCharToken(char ch, Token* token_ptr) { | |||||
| bool IsSingleCharToken(char ch, Token *token_ptr) { | |||||
| // clang-format off | // clang-format off | ||||
| std::unordered_map<char, Token> char_to_token = { | std::unordered_map<char, Token> char_to_token = { | ||||
| {'(', TOK_LPARENTHESIS}, | {'(', TOK_LPARENTHESIS}, | ||||
| @@ -806,7 +806,7 @@ class Lexer { | |||||
| Token GetNextToken() { | Token GetNextToken() { | ||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| Token token = GetNextTokenInner(); | Token token = GetNextTokenInner(); | ||||
| const char* str = token_text[token]; | |||||
| const char *str = token_text[token]; | |||||
| std::string text = (str == nullptr ? GetTokenText() : str); | std::string text = (str == nullptr ? GetTokenText() : str); | ||||
| MS_LOG(DEBUG) << "------Parse token] " << text; | MS_LOG(DEBUG) << "------Parse token] " << text; | ||||
| return token; | return token; | ||||
| @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; | |||||
| class IrParser { | class IrParser { | ||||
| public: | public: | ||||
| explicit IrParser(const char* filename) : lexer_(filename) {} | |||||
| explicit IrParser(const char *filename) : lexer_(filename) {} | |||||
| ~IrParser() {} | ~IrParser() {} | ||||
| py::object LoadObject(const std::string& file_name) const { | |||||
| py::object LoadObject(const std::string &file_name) const { | |||||
| std::string pkl_path = GetMsIrPath(); | std::string pkl_path = GetMsIrPath(); | ||||
| py::object default_obj = load_obj(pkl_path + "/" + file_name); | py::object default_obj = load_obj(pkl_path + "/" + file_name); | ||||
| return default_obj; | return default_obj; | ||||
| @@ -1087,7 +1087,7 @@ class IrParser { | |||||
| MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); | MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); | ||||
| } | } | ||||
| Token ParseParent(FuncGraphPtr* const parent_ptr) { | |||||
| Token ParseParent(FuncGraphPtr *const parent_ptr) { | |||||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER) { | if (lexer_.GetNextToken() != TOK_IDENTIFIER) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1168,7 +1168,7 @@ class IrParser { | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { | |||||
| FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { | |||||
| Token tok = lexer_.SkipWhiteToken(); | Token tok = lexer_.SkipWhiteToken(); | ||||
| while (tok == TOK_VARIABLE) { | while (tok == TOK_VARIABLE) { | ||||
| if (ParseStatement(func_graph) == nullptr) { | if (ParseStatement(func_graph) == nullptr) { | ||||
| @@ -1264,56 +1264,56 @@ class IrParser { | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { | |||||
| void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = dtype; | *ptr = dtype; | ||||
| } | } | ||||
| void SetTupleType(TypePtr* ptr) { | |||||
| void SetTupleType(TypePtr *ptr) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<Tuple>(); | *ptr = std::make_shared<Tuple>(); | ||||
| } | } | ||||
| void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { | |||||
| void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<Tuple>(elems); | *ptr = std::make_shared<Tuple>(elems); | ||||
| } | } | ||||
| void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector<int>&) { | |||||
| void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<TensorType>(elem_type); | *ptr = std::make_shared<TensorType>(elem_type); | ||||
| } | } | ||||
| void SetListType(TypePtr* ptr) { | |||||
| void SetListType(TypePtr *ptr) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<List>(); | *ptr = std::make_shared<List>(); | ||||
| } | } | ||||
| void SetListType(TypePtr* ptr, const TypePtrList& elems) { | |||||
| void SetListType(TypePtr *ptr, const TypePtrList &elems) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<List>(elems); | *ptr = std::make_shared<List>(elems); | ||||
| } | } | ||||
| void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { | |||||
| void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<JTagged>(elem); | *ptr = std::make_shared<JTagged>(elem); | ||||
| } | } | ||||
| void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { | |||||
| void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1321,45 +1321,45 @@ class IrParser { | |||||
| } | } | ||||
| // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} | // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} | ||||
| void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { | |||||
| void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<abstract::AbstractNone>(); | *ptr = std::make_shared<abstract::AbstractNone>(); | ||||
| } | } | ||||
| void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} | |||||
| void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} | |||||
| void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} | |||||
| void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} | |||||
| void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||||
| void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| // if one of elems is nullptr, just return | // if one of elems is nullptr, just return | ||||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<abstract::AbstractTuple>(elems); | *ptr = std::make_shared<abstract::AbstractTuple>(elems); | ||||
| } | } | ||||
| void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector<int>& shape) { | |||||
| void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector<int> &shape) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape); | *ptr = std::make_shared<abstract::AbstractTensor>(elem_type, shape); | ||||
| } | } | ||||
| void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { | |||||
| void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { | |||||
| if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { | |||||
| return; | return; | ||||
| } | } | ||||
| *ptr = std::make_shared<abstract::AbstractList>(elems); | *ptr = std::make_shared<abstract::AbstractList>(elems); | ||||
| } | } | ||||
| void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { | |||||
| void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { | |||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1367,7 +1367,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { | |||||
| Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { | |||||
| if (tok != TOK_LBRACKET) { | if (tok != TOK_LBRACKET) { | ||||
| MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; | MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; | ||||
| return tok; | return tok; | ||||
| @@ -1415,7 +1415,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||||
| Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||||
| if (tok != TOK_LPARENTHESIS) { | if (tok != TOK_LPARENTHESIS) { | ||||
| if (ptr != nullptr) { | if (ptr != nullptr) { | ||||
| SetBasicType(ptr, std::make_shared<TensorType>()); | SetBasicType(ptr, std::make_shared<TensorType>()); | ||||
| @@ -1454,7 +1454,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { | |||||
| bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { | |||||
| // clang-format off | // clang-format off | ||||
| static std::unordered_map<std::string, TypeId> basic_types = { | static std::unordered_map<std::string, TypeId> basic_types = { | ||||
| {"Bool", kNumberTypeBool}, | {"Bool", kNumberTypeBool}, | ||||
| @@ -1486,7 +1486,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { | |||||
| void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { | |||||
| TypePtr dtype = nullptr; | TypePtr dtype = nullptr; | ||||
| std::unordered_map<int, TypePtr> type_map = { | std::unordered_map<int, TypePtr> type_map = { | ||||
| @@ -1519,7 +1519,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { | |||||
| Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { | |||||
| if (type == "NoneType") { | if (type == "NoneType") { | ||||
| SetBasicType(ptr, std::make_shared<TypeNone>()); | SetBasicType(ptr, std::make_shared<TypeNone>()); | ||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| @@ -1541,7 +1541,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { | |||||
| Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { | |||||
| if (tok != TOK_IDENTIFIER) { | if (tok != TOK_IDENTIFIER) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1588,11 +1588,11 @@ class IrParser { | |||||
| } | } | ||||
| } | } | ||||
| Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { | |||||
| Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { | |||||
| return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); | return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); | ||||
| } | } | ||||
| Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||||
| Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||||
| Token tok = ParseAttribute(func_graph, prim); | Token tok = ParseAttribute(func_graph, prim); | ||||
| while (tok == TOK_COMMA) { | while (tok == TOK_COMMA) { | ||||
| tok = ParseAttribute(func_graph, prim); | tok = ParseAttribute(func_graph, prim); | ||||
| @@ -1603,7 +1603,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { | |||||
| Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { | |||||
| Token tok = lexer_.GetNextToken(); | Token tok = lexer_.GetNextToken(); | ||||
| if (tok != TOK_IDENTIFIER) { | if (tok != TOK_IDENTIFIER) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| @@ -1670,7 +1670,7 @@ class IrParser { | |||||
| return tok == TOK_RPARENTHESIS ? func_graph : nullptr; | return tok == TOK_RPARENTHESIS ? func_graph : nullptr; | ||||
| } | } | ||||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||||
| FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||||
| Token tok = ParseArgument(func_graph, inputs_ptr); | Token tok = ParseArgument(func_graph, inputs_ptr); | ||||
| while (tok == TOK_COMMA) { | while (tok == TOK_COMMA) { | ||||
| tok = ParseArgument(func_graph, inputs_ptr); | tok = ParseArgument(func_graph, inputs_ptr); | ||||
| @@ -1681,9 +1681,9 @@ class IrParser { | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { | |||||
| AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { | |||||
| while (func_graph != nullptr) { | while (func_graph != nullptr) { | ||||
| for (auto& ptr : func_graph->parameters()) { | |||||
| for (auto &ptr : func_graph->parameters()) { | |||||
| MS_EXCEPTION_IF_NULL(ptr); | MS_EXCEPTION_IF_NULL(ptr); | ||||
| ParameterPtr param = ptr->cast<ParameterPtr>(); | ParameterPtr param = ptr->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(param); | MS_EXCEPTION_IF_NULL(param); | ||||
| @@ -1701,12 +1701,12 @@ class IrParser { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bool Match(const std::string& str, const std::string& pattern) const { | |||||
| bool Match(const std::string &str, const std::string &pattern) const { | |||||
| return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; | return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; | ||||
| } | } | ||||
| template <typename T, typename V> | template <typename T, typename V> | ||||
| Token ParseScalar(ValuePtr* const val_ptr) { | |||||
| Token ParseScalar(ValuePtr *const val_ptr) { | |||||
| if (lexer_.GetNextToken() != TOK_NUMBER) { | if (lexer_.GetNextToken() != TOK_NUMBER) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1725,7 +1725,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename VT, typename V, typename T> | template <typename VT, typename V, typename T> | ||||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||||
| if (tok != TOK_LPARENTHESIS) { | if (tok != TOK_LPARENTHESIS) { | ||||
| *val_ptr = std::make_shared<T>(); | *val_ptr = std::make_shared<T>(); | ||||
| return tok; | return tok; | ||||
| @@ -1735,7 +1735,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename VT, typename V, typename T, const unsigned nbits> | template <typename VT, typename V, typename T, const unsigned nbits> | ||||
| Token ParseScalar(ValuePtr* const val_ptr, Token tok) { | |||||
| Token ParseScalar(ValuePtr *const val_ptr, Token tok) { | |||||
| if (tok != TOK_LPARENTHESIS) { | if (tok != TOK_LPARENTHESIS) { | ||||
| *val_ptr = std::make_shared<T>(nbits); | *val_ptr = std::make_shared<T>(nbits); | ||||
| return tok; | return tok; | ||||
| @@ -1745,7 +1745,7 @@ class IrParser { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| T StringToScalar(const std::string& text) { | |||||
| T StringToScalar(const std::string &text) { | |||||
| std::stringstream ss; | std::stringstream ss; | ||||
| T value; | T value; | ||||
| ss << text; | ss << text; | ||||
| @@ -1753,7 +1753,7 @@ class IrParser { | |||||
| return value; | return value; | ||||
| } | } | ||||
| Token ParseTensor(ValuePtr* const val_ptr) { | |||||
| Token ParseTensor(ValuePtr *const val_ptr) { | |||||
| // parse type | // parse type | ||||
| TypeId type; | TypeId type; | ||||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | ||||
| @@ -1803,7 +1803,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { | |||||
| Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { | |||||
| if (tok != TOK_LBRACE) { | if (tok != TOK_LBRACE) { | ||||
| return tok; | return tok; | ||||
| } | } | ||||
| @@ -1830,7 +1830,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||||
| Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||||
| if (tok != TOK_LPARENTHESIS) { | if (tok != TOK_LPARENTHESIS) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1855,7 +1855,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { | |||||
| Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { | |||||
| if (tok != TOK_LBRACE) { | if (tok != TOK_LBRACE) { | ||||
| return tok; | return tok; | ||||
| } | } | ||||
| @@ -1868,7 +1868,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseBoolValue(const std::string& key, bool* val_ptr) { | |||||
| Token ParseBoolValue(const std::string &key, bool *val_ptr) { | |||||
| if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { | if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1892,7 +1892,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { | |||||
| Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { | |||||
| if (lexer_.GetNextToken() != TOK_LBRACE) { | if (lexer_.GetNextToken() != TOK_LBRACE) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1920,7 +1920,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { | |||||
| Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { | |||||
| if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1951,7 +1951,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { | |||||
| Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { | |||||
| if (lexer_.GetNextToken() != TOK_AT_FILE) { | if (lexer_.GetNextToken() != TOK_AT_FILE) { | ||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| @@ -1984,7 +1984,7 @@ class IrParser { | |||||
| return next; | return next; | ||||
| } | } | ||||
| Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { | |||||
| Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { | |||||
| if (Match(id, "MultitypeFuncGraph::")) { | if (Match(id, "MultitypeFuncGraph::")) { | ||||
| std::string name = id.substr(strlen("MultitypeFuncGraph::")); | std::string name = id.substr(strlen("MultitypeFuncGraph::")); | ||||
| auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name); | auto mt_func_graph = std::make_shared<prim::MultitypeFuncGraph>(name); | ||||
| @@ -2024,8 +2024,8 @@ class IrParser { | |||||
| } | } | ||||
| } | } | ||||
| Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, | |||||
| AnfNodePtr* const node_ptr = nullptr) { | |||||
| Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, | |||||
| AnfNodePtr *const node_ptr = nullptr) { | |||||
| if (id == "None") { | if (id == "None") { | ||||
| *val_ptr = std::make_shared<None>(); | *val_ptr = std::make_shared<None>(); | ||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| @@ -2075,9 +2075,9 @@ class IrParser { | |||||
| } | } | ||||
| } | } | ||||
| Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, | |||||
| const std::vector<ValuePtr>& elems, const std::vector<AnfNodePtr>& nodes, | |||||
| ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { | |||||
| Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, | |||||
| const std::vector<ValuePtr> &elems, const std::vector<AnfNodePtr> &nodes, | |||||
| ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { | |||||
| if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { | if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { | ||||
| if (node_is_valid && node_ptr != nullptr) { | if (node_is_valid && node_ptr != nullptr) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -2097,8 +2097,8 @@ class IrParser { | |||||
| } | } | ||||
| } | } | ||||
| Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, | |||||
| AnfNodePtr* node_ptr = nullptr) { | |||||
| Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, | |||||
| AnfNodePtr *node_ptr = nullptr) { | |||||
| Token left_tok = tok; | Token left_tok = tok; | ||||
| std::vector<ValuePtr> elems; | std::vector<ValuePtr> elems; | ||||
| @@ -2138,7 +2138,7 @@ class IrParser { | |||||
| return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); | return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); | ||||
| } | } | ||||
| Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { | |||||
| Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { | |||||
| // tuple or list | // tuple or list | ||||
| if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { | if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { | ||||
| return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); | return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); | ||||
| @@ -2152,7 +2152,7 @@ class IrParser { | |||||
| return TOK_ERROR; | return TOK_ERROR; | ||||
| } | } | ||||
| Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, | |||||
| Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, | |||||
| Token tok = TOK_INVALID) { | Token tok = TOK_INVALID) { | ||||
| if (tok == TOK_INVALID) { | if (tok == TOK_INVALID) { | ||||
| tok = lexer_.GetNextToken(); | tok = lexer_.GetNextToken(); | ||||
| @@ -2193,7 +2193,7 @@ class IrParser { | |||||
| return lexer_.GetNextToken(); | return lexer_.GetNextToken(); | ||||
| } | } | ||||
| Token ParseArgument(const FuncGraphPtr& func_graph, std::vector<AnfNodePtr>* const inputs_ptr) { | |||||
| Token ParseArgument(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *const inputs_ptr) { | |||||
| Token tok = lexer_.GetNextToken(); | Token tok = lexer_.GetNextToken(); | ||||
| if (tok == TOK_RPARENTHESIS) { | if (tok == TOK_RPARENTHESIS) { | ||||
| return tok; | return tok; | ||||
| @@ -2208,7 +2208,7 @@ class IrParser { | |||||
| return tok; | return tok; | ||||
| } | } | ||||
| const std::vector<FuncGraphPtr>& GetFuncGraphs() const { return func_graphs_; } | |||||
| const std::vector<FuncGraphPtr> &GetFuncGraphs() const { return func_graphs_; } | |||||
| private: | private: | ||||
| Lexer lexer_; | Lexer lexer_; | ||||
| @@ -2226,14 +2226,14 @@ class IrParser { | |||||
| std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter | std::map<std::string, ParameterPtr> param_nodes_; // map parameter name to parameter | ||||
| }; | }; | ||||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename) { | |||||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename) { | |||||
| IrParser parser(filename.c_str()); | IrParser parser(filename.c_str()); | ||||
| parser.ParseFile(); | parser.ParseFile(); | ||||
| return parser.GetFuncGraphs(); | return parser.GetFuncGraphs(); | ||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | #ifdef ENABLE_DUMP_IR | ||||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Func graph is nullptr"; | MS_LOG(ERROR) << "Func graph is nullptr"; | ||||
| return; | return; | ||||
| @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||||
| return; | return; | ||||
| } | } | ||||
| char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
| char* real_path_ret = nullptr; | |||||
| char *real_path_ret = nullptr; | |||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | ||||
| #else | #else | ||||
| @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { | |||||
| ChangeFileMode(file_path, S_IRUSR); | ChangeFileMode(file_path, S_IRUSR); | ||||
| } | } | ||||
| #else | #else | ||||
| void DumpIRProto(const FuncGraphPtr&, const std::string&) { | |||||
| void DumpIRProto(const FuncGraphPtr &, const std::string &) { | |||||
| static bool already_printed = false; | static bool already_printed = false; | ||||
| if (already_printed) { | if (already_printed) { | ||||
| return; | return; | ||||
| @@ -39,7 +39,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| struct ParamPtrEqual { | struct ParamPtrEqual { | ||||
| bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { | |||||
| bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { | |||||
| const ParameterPtr param1 = dyn_cast<Parameter>(t1); | const ParameterPtr param1 = dyn_cast<Parameter>(t1); | ||||
| const ParameterPtr param2 = dyn_cast<Parameter>(t2); | const ParameterPtr param2 = dyn_cast<Parameter>(t2); | ||||
| @@ -52,7 +52,7 @@ struct ParamPtrEqual { | |||||
| }; | }; | ||||
| struct ParamPtrHasher { | struct ParamPtrHasher { | ||||
| std::size_t operator()(AnfNodePtr const& param) const { | |||||
| std::size_t operator()(AnfNodePtr const ¶m) const { | |||||
| const ParameterPtr parameter = dyn_cast<Parameter>(param); | const ParameterPtr parameter = dyn_cast<Parameter>(param); | ||||
| if (parameter == nullptr) { | if (parameter == nullptr) { | ||||
| return 0; | return 0; | ||||
| @@ -64,39 +64,39 @@ struct ParamPtrHasher { | |||||
| class AnfExporter { | class AnfExporter { | ||||
| public: | public: | ||||
| explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) | |||||
| explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) | |||||
| : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { | ||||
| func_graph_set.clear(); | func_graph_set.clear(); | ||||
| exported.clear(); | exported.clear(); | ||||
| } | } | ||||
| virtual ~AnfExporter() {} | virtual ~AnfExporter() {} | ||||
| void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||||
| void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||||
| void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||||
| void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||||
| protected: | protected: | ||||
| virtual std::string GetNodeType(const AnfNodePtr& nd); | |||||
| int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); | |||||
| int GetParamIndexFromExported(const AnfNodePtr& param); | |||||
| std::string DumpObject(const py::object& obj, const std::string& category) const; | |||||
| std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); | |||||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); | |||||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); | |||||
| std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||||
| std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||||
| std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||||
| std::string GetPrimitiveText(const PrimitivePtr& prim); | |||||
| std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); | |||||
| std::string GetNameSpaceText(const parse::NameSpacePtr& ns); | |||||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); | |||||
| std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||||
| const std::map<AnfNodePtr, int>& apply_map); | |||||
| void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); | |||||
| void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, | |||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map); | |||||
| void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); | |||||
| void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph); | |||||
| virtual std::string GetNodeType(const AnfNodePtr &nd); | |||||
| int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); | |||||
| int GetParamIndexFromExported(const AnfNodePtr ¶m); | |||||
| std::string DumpObject(const py::object &obj, const std::string &category) const; | |||||
| std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); | |||||
| std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); | |||||
| std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); | |||||
| std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||||
| std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||||
| std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||||
| std::string GetPrimitiveText(const PrimitivePtr &prim); | |||||
| std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); | |||||
| std::string GetNameSpaceText(const parse::NameSpacePtr &ns); | |||||
| std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); | |||||
| std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::map<AnfNodePtr, int> &apply_map); | |||||
| void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); | |||||
| void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> ¶meters, | |||||
| OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map); | |||||
| void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); | |||||
| void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); | |||||
| int param_index; | int param_index; | ||||
| OrderedSet<FuncGraphPtr> func_graph_set{}; | OrderedSet<FuncGraphPtr> func_graph_set{}; | ||||
| @@ -108,16 +108,16 @@ class AnfExporter { | |||||
| abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | abstract::AnfNodeConfigPtr node_cfg_ = nullptr; | ||||
| }; | }; | ||||
| void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); | |||||
| void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs); | |||||
| void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); | |||||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | |||||
| std::vector<FuncGraphPtr> ImportIR(const std::string& filename); | |||||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename); | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||||
| void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | ||||
| @@ -34,7 +34,7 @@ namespace draw { | |||||
| namespace { | namespace { | ||||
| // Only for ValueNode | // Only for ValueNode | ||||
| std::string ValueType(const ValueNodePtr& node) { | |||||
| std::string ValueType(const ValueNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { | |||||
| return v->type_name(); | return v->type_name(); | ||||
| } | } | ||||
| std::string ReplaceSpecialChar(const std::string& str) { | |||||
| std::string ReplaceSpecialChar(const std::string &str) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| for (size_t i = 0; i < str.size(); i++) { | for (size_t i = 0; i < str.size(); i++) { | ||||
| if (str[i] == '<') { | if (str[i] == '<') { | ||||
| @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { | |||||
| } // namespace | } // namespace | ||||
| // API of debug utils | // API of debug utils | ||||
| void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs, | |||||
| void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs, | |||||
| bool is_user) { | bool is_user) { | ||||
| if (sub_graphs == nullptr) { | if (sub_graphs == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto& nd : nodes) { | |||||
| for (auto &nd : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(nd); | MS_EXCEPTION_IF_NULL(nd); | ||||
| auto sub_graph = nd->func_graph(); | auto sub_graph = nd->func_graph(); | ||||
| if (sub_graph != nullptr) { | if (sub_graph != nullptr) { | ||||
| @@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st | |||||
| } | } | ||||
| } | } | ||||
| void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) { | |||||
| void DrawValueNodes(const std::vector<AnfNodePtr> &nodes, | |||||
| OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) { | |||||
| if (sub_graphs == nullptr) { | if (sub_graphs == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| int dup_idx = 0; | int dup_idx = 0; | ||||
| for (auto& nd : nodes) { | |||||
| for (auto& t : SuccIncoming(nd)) { | |||||
| for (auto &nd : nodes) { | |||||
| for (auto &t : SuccIncoming(nd)) { | |||||
| MS_EXCEPTION_IF_NULL(t); | MS_EXCEPTION_IF_NULL(t); | ||||
| MS_EXCEPTION_IF_NULL(nd); | MS_EXCEPTION_IF_NULL(nd); | ||||
| if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { | if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { | ||||
| @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, | |||||
| } | } | ||||
| } | } | ||||
| void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) { | |||||
| void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) { | |||||
| if (digraph == nullptr) { | if (digraph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||||
| } | } | ||||
| // Draw edge | // Draw edge | ||||
| for (auto& nd : nodes) { | |||||
| for (auto &nd : nodes) { | |||||
| auto succs = SuccIncoming(nd); | auto succs = SuccIncoming(nd); | ||||
| auto num = succs.size(); | auto num = succs.size(); | ||||
| for (size_t i = 0; i < num; i++) { | for (size_t i = 0; i < num; i++) { | ||||
| auto& t = succs.at(i); | |||||
| auto &t = succs.at(i); | |||||
| MS_EXCEPTION_IF_NULL(t); | MS_EXCEPTION_IF_NULL(t); | ||||
| if (t->isa<ValueNode>() || t->isa<Parameter>()) { | if (t->isa<ValueNode>() || t->isa<Parameter>()) { | ||||
| if ((!is_user) || (i != 0)) { | if ((!is_user) || (i != 0)) { | ||||
| @@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD | |||||
| } | } | ||||
| } | } | ||||
| void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) { | |||||
| void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||||
| DrawValueNodes(nodes, &sub_graphs); | DrawValueNodes(nodes, &sub_graphs); | ||||
| // Draw subgraph | // Draw subgraph | ||||
| for (const auto& gsub : sub_graphs) { | |||||
| for (const auto &gsub : sub_graphs) { | |||||
| digraph->SubGraph(gsub.first, gsub.second); | digraph->SubGraph(gsub.first, gsub.second); | ||||
| } | } | ||||
| @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use | |||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | #ifdef ENABLE_DUMP_IR | ||||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { | |||||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { | |||||
| const std::string dot_suffix = ".dot"; | const std::string dot_suffix = ".dot"; | ||||
| std::string filename_with_suffix = | std::string filename_with_suffix = | ||||
| (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; | (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; | ||||
| DrawByOpt(filename_with_suffix, func_graph, false); | DrawByOpt(filename_with_suffix, func_graph, false); | ||||
| } | } | ||||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { | |||||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { | |||||
| DrawByOpt(filename, func_graph, true); | DrawByOpt(filename, func_graph, true); | ||||
| } | } | ||||
| #else | #else | ||||
| void Draw(const std::string&, const FuncGraphPtr&) { | |||||
| void Draw(const std::string &, const FuncGraphPtr &) { | |||||
| static bool already_printed = false; | static bool already_printed = false; | ||||
| if (already_printed) { | if (already_printed) { | ||||
| return; | return; | ||||
| @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { | |||||
| << "please recompile source to enable it. See help of building script."; | << "please recompile source to enable it. See help of building script."; | ||||
| } | } | ||||
| void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { | |||||
| void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { | |||||
| static bool already_printed = false; | static bool already_printed = false; | ||||
| if (already_printed) { | if (already_printed) { | ||||
| return; | return; | ||||
| @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { | |||||
| return "plaintext"; | return "plaintext"; | ||||
| } | } | ||||
| std::string Graphviz::Color(const AnfNodePtr& node) { | |||||
| std::string Graphviz::Color(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -259,7 +259,7 @@ void BaseDigraph::Start() { | |||||
| buffer_ << "compound=true" << std::endl; | buffer_ << "compound=true" << std::endl; | ||||
| } | } | ||||
| void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||||
| void BaseDigraph::Head(const AnfNodePtr &node, int id) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { | |||||
| } | } | ||||
| } | } | ||||
| void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||||
| void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { | |||||
| buffer_ << ":" << idx; | buffer_ << ":" << idx; | ||||
| } | } | ||||
| void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { | |||||
| void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -304,12 +304,12 @@ void BaseDigraph::End() { | |||||
| } | } | ||||
| } | } | ||||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||||
| void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { | |||||
| buffer_ << "parameters_" << key << "[shape=plaintext "; | buffer_ << "parameters_" << key << "[shape=plaintext "; | ||||
| buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; | buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; | ||||
| buffer_ << "<tr><td>parameters</td></tr>"; | buffer_ << "<tr><td>parameters</td></tr>"; | ||||
| int count = 0; | int count = 0; | ||||
| for (auto& parameter : key->parameters()) { | |||||
| for (auto ¶meter : key->parameters()) { | |||||
| buffer_ << "<tr><td>"; | buffer_ << "<tr><td>"; | ||||
| buffer_ << parameter->ToString(); | buffer_ << parameter->ToString(); | ||||
| auto py_p = dyn_cast<Parameter>(parameter)->default_param(); | auto py_p = dyn_cast<Parameter>(parameter)->default_param(); | ||||
| @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { | |||||
| buffer_ << "</table>>,];"; | buffer_ << "</table>>,];"; | ||||
| } | } | ||||
| void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) { | |||||
| void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) { | |||||
| if (key == nullptr || gsub == nullptr) { | if (key == nullptr || gsub == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -361,12 +361,12 @@ Digraph::~Digraph() { | |||||
| if (fout_.is_open()) { | if (fout_.is_open()) { | ||||
| fout_.close(); | fout_.close(); | ||||
| } | } | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(ERROR) << "Exception when closing file " << filename_; | MS_LOG(ERROR) << "Exception when closing file " << filename_; | ||||
| } | } | ||||
| } | } | ||||
| static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { | |||||
| static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { | |||||
| size_t start_pos = 0; | size_t start_pos = 0; | ||||
| while ((start_pos = str.find(from, start_pos)) != std::string::npos) { | while ((start_pos = str.find(from, start_pos)) != std::string::npos) { | ||||
| (void)str.replace(start_pos, from.length(), to); | (void)str.replace(start_pos, from.length(), to); | ||||
| @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st | |||||
| return str; | return str; | ||||
| } | } | ||||
| static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||||
| static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(graph_obj); | MS_EXCEPTION_IF_NULL(graph_obj); | ||||
| graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) | graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) | ||||
| << "'>"; | << "'>"; | ||||
| @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||||
| graph_obj->buffer() << "</td></tr>"; | graph_obj->buffer() << "</td></tr>"; | ||||
| graph_obj->buffer() << "<tr><td align='left'>"; | graph_obj->buffer() << "<tr><td align='left'>"; | ||||
| int i = 0; | int i = 0; | ||||
| for (const auto& attr : attrs) { | |||||
| for (const auto &attr : attrs) { | |||||
| if (i != 0) { | if (i != 0) { | ||||
| graph_obj->buffer() << "<br/>"; | graph_obj->buffer() << "<br/>"; | ||||
| } | } | ||||
| @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { | |||||
| graph_obj->buffer() << "</table>>,"; | graph_obj->buffer() << "</table>>,"; | ||||
| } | } | ||||
| static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||||
| static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||||
| if (graph_obj == nullptr || node == nullptr) { | if (graph_obj == nullptr || node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { | |||||
| } | } | ||||
| } | } | ||||
| static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||||
| static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { | |||||
| if (graph_obj == nullptr || node == nullptr || node->size() == 0) { | if (graph_obj == nullptr || node == nullptr || node->size() == 0) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { | |||||
| } | } | ||||
| graph_obj->buffer() << ">"; | graph_obj->buffer() << ">"; | ||||
| int i = 0; | int i = 0; | ||||
| for (auto& attr : attrs) { | |||||
| for (auto &attr : attrs) { | |||||
| if (i != 0) { | if (i != 0) { | ||||
| graph_obj->buffer() << "<br/>"; | graph_obj->buffer() << "<br/>"; | ||||
| } | } | ||||
| @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { | |||||
| if (fout_.is_open()) { | if (fout_.is_open()) { | ||||
| fout_.close(); | fout_.close(); | ||||
| } | } | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(ERROR) << "exception when closing file " << filename_; | MS_LOG(ERROR) << "exception when closing file " << filename_; | ||||
| } | } | ||||
| } | } | ||||
| @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; | |||||
| class Graphviz { | class Graphviz { | ||||
| public: | public: | ||||
| Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} | |||||
| Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} | |||||
| explicit Graphviz(const std::string& name) : name_(name) {} | |||||
| explicit Graphviz(const std::string &name) : name_(name) {} | |||||
| virtual ~Graphviz() {} | virtual ~Graphviz() {} | ||||
| @@ -41,8 +41,8 @@ class Graphviz { | |||||
| virtual void End() {} | virtual void End() {} | ||||
| virtual std::string Shape(AnfNodePtr node); | virtual std::string Shape(AnfNodePtr node); | ||||
| std::string Color(const AnfNodePtr& node); | |||||
| std::ostringstream& buffer() { return buffer_; } | |||||
| std::string Color(const AnfNodePtr &node); | |||||
| std::ostringstream &buffer() { return buffer_; } | |||||
| std::ostringstream buffer_; | std::ostringstream buffer_; | ||||
| protected: | protected: | ||||
| @@ -53,8 +53,8 @@ class Graphviz { | |||||
| class BaseDigraph : public Graphviz { | class BaseDigraph : public Graphviz { | ||||
| public: | public: | ||||
| BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} | |||||
| explicit BaseDigraph(const std::string& name) : Graphviz(name) {} | |||||
| BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} | |||||
| explicit BaseDigraph(const std::string &name) : Graphviz(name) {} | |||||
| ~BaseDigraph() override = default; | ~BaseDigraph() override = default; | ||||
| virtual void Node(AnfNodePtr node, int id = 0) = 0; | virtual void Node(AnfNodePtr node, int id = 0) = 0; | ||||
| @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { | |||||
| void Start() override; | void Start() override; | ||||
| void End() override; | void End() override; | ||||
| virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); | virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); | ||||
| void FuncGraphParameters(const FuncGraphPtr& key); | |||||
| void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub); | |||||
| void FuncGraphParameters(const FuncGraphPtr &key); | |||||
| void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub); | |||||
| const std::string& name() const { return name_; } | |||||
| const std::string &name() const { return name_; } | |||||
| protected: | protected: | ||||
| void Head(const AnfNodePtr& node, int id = 0); | |||||
| void Tail(const AnfNodePtr& node, int idx, int id = 0); | |||||
| void Tail(const FuncGraphPtr& func_graph); | |||||
| void Head(const AnfNodePtr &node, int id = 0); | |||||
| void Tail(const AnfNodePtr &node, int idx, int id = 0); | |||||
| void Tail(const FuncGraphPtr &func_graph); | |||||
| }; | }; | ||||
| class Digraph : public BaseDigraph { | class Digraph : public BaseDigraph { | ||||
| public: | public: | ||||
| Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||||
| explicit Digraph(const std::string& name) : BaseDigraph(name) {} | |||||
| Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||||
| explicit Digraph(const std::string &name) : BaseDigraph(name) {} | |||||
| ~Digraph() override; | ~Digraph() override; | ||||
| void Node(AnfNodePtr node, int id = 0) override; | void Node(AnfNodePtr node, int id = 0) override; | ||||
| @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { | |||||
| class ModelDigraph : public BaseDigraph { | class ModelDigraph : public BaseDigraph { | ||||
| public: | public: | ||||
| ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} | |||||
| explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} | |||||
| ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} | |||||
| explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} | |||||
| ~ModelDigraph() override; | ~ModelDigraph() override; | ||||
| std::string Shape(AnfNodePtr node) override; | std::string Shape(AnfNodePtr node) override; | ||||
| @@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { | |||||
| }; | }; | ||||
| // API to draw | // API to draw | ||||
| void Draw(const std::string& filename, const FuncGraphPtr& func_graph); | |||||
| void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); | |||||
| void Draw(const std::string &filename, const FuncGraphPtr &func_graph); | |||||
| void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); | |||||
| } // namespace draw | } // namespace draw | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,38 +33,38 @@ class ProtoExporter { | |||||
| ProtoExporter() {} | ProtoExporter() {} | ||||
| ~ProtoExporter() {} | ~ProtoExporter() {} | ||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||||
| private: | private: | ||||
| void InitModelInfo(); | void InitModelInfo(); | ||||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); | |||||
| std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, | |||||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||||
| void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); | |||||
| void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); | |||||
| void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); | |||||
| void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); | |||||
| void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); | |||||
| void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); | |||||
| void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||||
| void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); | |||||
| void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr); | |||||
| void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto); | |||||
| void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||||
| const std::map<AnfNodePtr, size_t>& apply_map, std::map<AnfNodePtr, size_t>* const_map_ptr, | |||||
| irpb::GraphProto* graph_proto); | |||||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto); | |||||
| void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto); | |||||
| std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||||
| void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); | |||||
| void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); | |||||
| void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); | |||||
| void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); | |||||
| void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); | |||||
| void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); | |||||
| void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||||
| void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); | |||||
| void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr); | |||||
| void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto); | |||||
| void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||||
| const std::map<AnfNodePtr, size_t> &apply_map, std::map<AnfNodePtr, size_t> *const_map_ptr, | |||||
| irpb::GraphProto *graph_proto); | |||||
| void ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto); | |||||
| static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } | static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } | ||||
| irpb::ModelProto model_; | irpb::ModelProto model_; | ||||
| }; | }; | ||||
| static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||||
| static irpb::DataType GetNumberDataType(const TypePtr &type) { | |||||
| switch (type->type_id()) { | switch (type->type_id()) { | ||||
| case kNumberTypeBool: | case kNumberTypeBool: | ||||
| return irpb::DT_BOOL; | return irpb::DT_BOOL; | ||||
| @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { | |||||
| void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { | |||||
| if (type_proto == nullptr) { | if (type_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||||
| type_proto->set_data_type(irpb::DT_TENSOR); | type_proto->set_data_type(irpb::DT_TENSOR); | ||||
| if (shape != nullptr && shape->isa<abstract::Shape>()) { | if (shape != nullptr && shape->isa<abstract::Shape>()) { | ||||
| abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape); | ||||
| for (const auto& elem : shape_info->shape()) { | |||||
| for (const auto &elem : shape_info->shape()) { | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | ||||
| } | } | ||||
| } | } | ||||
| } else if (type->isa<Tuple>()) { | } else if (type->isa<Tuple>()) { | ||||
| TuplePtr tuple_type = dyn_cast<Tuple>(type); | TuplePtr tuple_type = dyn_cast<Tuple>(type); | ||||
| type_proto->set_data_type(irpb::DT_TUPLE); | type_proto->set_data_type(irpb::DT_TUPLE); | ||||
| for (const auto& elem_type : tuple_type->elements()) { | |||||
| for (const auto &elem_type : tuple_type->elements()) { | |||||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | ||||
| } | } | ||||
| } else if (type->isa<TypeType>()) { | } else if (type->isa<TypeType>()) { | ||||
| @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||||
| } else if (type->isa<List>()) { | } else if (type->isa<List>()) { | ||||
| ListPtr list_type = dyn_cast<List>(type); | ListPtr list_type = dyn_cast<List>(type); | ||||
| type_proto->set_data_type(irpb::DT_LIST); | type_proto->set_data_type(irpb::DT_LIST); | ||||
| for (const auto& elem_type : list_type->elements()) { | |||||
| for (const auto &elem_type : list_type->elements()) { | |||||
| SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); | ||||
| } | } | ||||
| } else if (type->isa<TypeAnything>()) { | } else if (type->isa<TypeAnything>()) { | ||||
| @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { | |||||
| void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { | |||||
| if (node == nullptr || type_proto == nullptr) { | if (node == nullptr || type_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| SetNodeOutputType(node->Type(), node->Shape(), type_proto); | SetNodeOutputType(node->Type(), node->Shape(), type_proto); | ||||
| } | } | ||||
| void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { | |||||
| void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { | |||||
| if (val == nullptr || value_proto == nullptr) { | if (val == nullptr || value_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (val->isa<StringImm>()) { | if (val->isa<StringImm>()) { | ||||
| const StringImmPtr& value = dyn_cast<StringImm>(val); | |||||
| const StringImmPtr &value = dyn_cast<StringImm>(val); | |||||
| value_proto->set_dtype(irpb::DT_STRING); | value_proto->set_dtype(irpb::DT_STRING); | ||||
| value_proto->set_str_val(value->value()); | value_proto->set_str_val(value->value()); | ||||
| } else if (val->isa<Scalar>()) { | } else if (val->isa<Scalar>()) { | ||||
| @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||||
| } else if (val->isa<tensor::Tensor>()) { | } else if (val->isa<tensor::Tensor>()) { | ||||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val); | tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(val); | ||||
| value_proto->set_dtype(irpb::DT_TENSOR); | value_proto->set_dtype(irpb::DT_TENSOR); | ||||
| irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); | |||||
| irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); | |||||
| tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); | tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); | ||||
| for (auto& elem : tensor_ptr->shape()) { | |||||
| for (auto &elem : tensor_ptr->shape()) { | |||||
| tensor_proto->add_dims(elem); | tensor_proto->add_dims(elem); | ||||
| } | } | ||||
| } else if (val->isa<TensorType>()) { | } else if (val->isa<TensorType>()) { | ||||
| value_proto->set_dtype(irpb::DT_TYPE); | value_proto->set_dtype(irpb::DT_TYPE); | ||||
| irpb::TypeProto* type_proto = value_proto->mutable_type_val(); | |||||
| irpb::TypeProto *type_proto = value_proto->mutable_type_val(); | |||||
| type_proto->set_data_type(irpb::DT_TENSOR); | type_proto->set_data_type(irpb::DT_TENSOR); | ||||
| TypePtr elem_type = dyn_cast<TensorType>(val)->element(); | TypePtr elem_type = dyn_cast<TensorType>(val)->element(); | ||||
| type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); | ||||
| @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { | |||||
| void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { | |||||
| if (val == nullptr || value_proto == nullptr) { | if (val == nullptr || value_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (val->isa<BoolImm>()) { | if (val->isa<BoolImm>()) { | ||||
| const BoolImmPtr& value = dyn_cast<BoolImm>(val); | |||||
| const BoolImmPtr &value = dyn_cast<BoolImm>(val); | |||||
| value_proto->set_dtype(irpb::DT_BOOL); | value_proto->set_dtype(irpb::DT_BOOL); | ||||
| value_proto->set_bool_val(value->value()); | value_proto->set_bool_val(value->value()); | ||||
| } else if (val->isa<Int8Imm>()) { | } else if (val->isa<Int8Imm>()) { | ||||
| const Int8ImmPtr& value = dyn_cast<Int8Imm>(val); | |||||
| const Int8ImmPtr &value = dyn_cast<Int8Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_INT8); | value_proto->set_dtype(irpb::DT_INT8); | ||||
| value_proto->set_int_val(value->value()); | value_proto->set_int_val(value->value()); | ||||
| } else if (val->isa<Int16Imm>()) { | } else if (val->isa<Int16Imm>()) { | ||||
| const Int16ImmPtr& value = dyn_cast<Int16Imm>(val); | |||||
| const Int16ImmPtr &value = dyn_cast<Int16Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_INT16); | value_proto->set_dtype(irpb::DT_INT16); | ||||
| value_proto->set_int_val(value->value()); | value_proto->set_int_val(value->value()); | ||||
| } else if (val->isa<Int32Imm>()) { | } else if (val->isa<Int32Imm>()) { | ||||
| const Int32ImmPtr& value = dyn_cast<Int32Imm>(val); | |||||
| const Int32ImmPtr &value = dyn_cast<Int32Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_INT32); | value_proto->set_dtype(irpb::DT_INT32); | ||||
| value_proto->set_int_val(value->value()); | value_proto->set_int_val(value->value()); | ||||
| } else if (val->isa<Int64Imm>()) { | } else if (val->isa<Int64Imm>()) { | ||||
| const Int64ImmPtr& value = dyn_cast<Int64Imm>(val); | |||||
| const Int64ImmPtr &value = dyn_cast<Int64Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_INT64); | value_proto->set_dtype(irpb::DT_INT64); | ||||
| value_proto->set_int_val(value->value()); | value_proto->set_int_val(value->value()); | ||||
| } else if (val->isa<UInt8Imm>()) { | } else if (val->isa<UInt8Imm>()) { | ||||
| const UInt8ImmPtr& value = dyn_cast<UInt8Imm>(val); | |||||
| const UInt8ImmPtr &value = dyn_cast<UInt8Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_UINT8); | value_proto->set_dtype(irpb::DT_UINT8); | ||||
| value_proto->set_uint_val(value->value()); | value_proto->set_uint_val(value->value()); | ||||
| } else if (val->isa<UInt16Imm>()) { | } else if (val->isa<UInt16Imm>()) { | ||||
| const UInt16ImmPtr& value = dyn_cast<UInt16Imm>(val); | |||||
| const UInt16ImmPtr &value = dyn_cast<UInt16Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_UINT16); | value_proto->set_dtype(irpb::DT_UINT16); | ||||
| value_proto->set_uint_val(value->value()); | value_proto->set_uint_val(value->value()); | ||||
| } else if (val->isa<UInt32Imm>()) { | } else if (val->isa<UInt32Imm>()) { | ||||
| const UInt32ImmPtr& value = dyn_cast<UInt32Imm>(val); | |||||
| const UInt32ImmPtr &value = dyn_cast<UInt32Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_UINT32); | value_proto->set_dtype(irpb::DT_UINT32); | ||||
| value_proto->set_uint_val(value->value()); | value_proto->set_uint_val(value->value()); | ||||
| } else if (val->isa<UInt64Imm>()) { | } else if (val->isa<UInt64Imm>()) { | ||||
| const UInt64ImmPtr& value = dyn_cast<UInt64Imm>(val); | |||||
| const UInt64ImmPtr &value = dyn_cast<UInt64Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_UINT64); | value_proto->set_dtype(irpb::DT_UINT64); | ||||
| value_proto->set_uint_val(value->value()); | value_proto->set_uint_val(value->value()); | ||||
| } else if (val->isa<FP32Imm>()) { | } else if (val->isa<FP32Imm>()) { | ||||
| const FP32ImmPtr& value = dyn_cast<FP32Imm>(val); | |||||
| const FP32ImmPtr &value = dyn_cast<FP32Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_FLOAT32); | value_proto->set_dtype(irpb::DT_FLOAT32); | ||||
| value_proto->set_float_val(value->value()); | value_proto->set_float_val(value->value()); | ||||
| } else if (val->isa<FP64Imm>()) { | } else if (val->isa<FP64Imm>()) { | ||||
| const FP64ImmPtr& value = dyn_cast<FP64Imm>(val); | |||||
| const FP64ImmPtr &value = dyn_cast<FP64Imm>(val); | |||||
| value_proto->set_dtype(irpb::DT_FLOAT64); | value_proto->set_dtype(irpb::DT_FLOAT64); | ||||
| value_proto->set_double_val(value->value()); | value_proto->set_double_val(value->value()); | ||||
| } else { | } else { | ||||
| @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { | |||||
| void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { | |||||
| if (val == nullptr || value_proto == nullptr) { | if (val == nullptr || value_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (val->isa<ValueTuple>()) { | if (val->isa<ValueTuple>()) { | ||||
| const ValueTuplePtr& value = dyn_cast<ValueTuple>(val); | |||||
| const ValueTuplePtr &value = dyn_cast<ValueTuple>(val); | |||||
| value_proto->set_dtype(irpb::DT_TUPLE); | value_proto->set_dtype(irpb::DT_TUPLE); | ||||
| for (const auto& item : value->value()) { | |||||
| for (const auto &item : value->value()) { | |||||
| SetValueToProto(item, value_proto->add_values()); | SetValueToProto(item, value_proto->add_values()); | ||||
| } | } | ||||
| } else if (val->isa<ValueList>()) { | } else if (val->isa<ValueList>()) { | ||||
| const ValueListPtr& value = dyn_cast<ValueList>(val); | |||||
| const ValueListPtr &value = dyn_cast<ValueList>(val); | |||||
| value_proto->set_dtype(irpb::DT_LIST); | value_proto->set_dtype(irpb::DT_LIST); | ||||
| for (const auto& item : value->value()) { | |||||
| for (const auto &item : value->value()) { | |||||
| SetValueToProto(item, value_proto->add_values()); | SetValueToProto(item, value_proto->add_values()); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { | |||||
| void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { | |||||
| if (val == nullptr || value_proto == nullptr) { | if (val == nullptr || value_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| value_proto->set_dtype(irpb::DT_DICT); | value_proto->set_dtype(irpb::DT_DICT); | ||||
| for (const auto& item : val->value()) { | |||||
| irpb::NamedValueProto* named_val = value_proto->add_dict_val(); | |||||
| for (const auto &item : val->value()) { | |||||
| irpb::NamedValueProto *named_val = value_proto->add_dict_val(); | |||||
| named_val->set_key(item.first); | named_val->set_key(item.first); | ||||
| SetValueToProto(item.second, named_val->mutable_value()); | SetValueToProto(item.second, named_val->mutable_value()); | ||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { | |||||
| void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { | |||||
| if (node == nullptr || node_proto == nullptr) { | if (node == nullptr || node_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& | |||||
| MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); | MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); | ||||
| } | } | ||||
| const PrimitivePtr& prim = GetValueNode<PrimitivePtr>(node); | |||||
| const PrimitivePtr &prim = GetValueNode<PrimitivePtr>(node); | |||||
| node_proto->set_op_type(prim->name()); | node_proto->set_op_type(prim->name()); | ||||
| for (const auto& attr : prim->attrs()) { | |||||
| irpb::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| for (const auto &attr : prim->attrs()) { | |||||
| irpb::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name(attr.first); | attr_proto->set_name(attr.first); | ||||
| SetValueToProto(attr.second, attr_proto->mutable_value()); | SetValueToProto(attr.second, attr_proto->mutable_value()); | ||||
| } | } | ||||
| node_proto->set_scope(node->scope()->name()); | node_proto->set_scope(node->scope()->name()); | ||||
| } | } | ||||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, | |||||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||||
| std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, | |||||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||||
| if (node == nullptr || const_map_ptr == nullptr) { | if (node == nullptr || const_map_ptr == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt | |||||
| MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; | MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; | ||||
| } | } | ||||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||||
| std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| InitModelInfo(); | InitModelInfo(); | ||||
| irpb::GraphProto* graph_proto = model_.mutable_graph(); | |||||
| irpb::GraphProto *graph_proto = model_.mutable_graph(); | |||||
| ExportFuncGraph(func_graph, graph_proto); | ExportFuncGraph(func_graph, graph_proto); | ||||
| return model_.SerializeAsString(); | return model_.SerializeAsString(); | ||||
| } | } | ||||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||||
| void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||||
| if (func_graph == nullptr || graph_proto == nullptr) { | if (func_graph == nullptr || graph_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP | |||||
| ExportValueNodes(const_map, graph_proto); | ExportValueNodes(const_map, graph_proto); | ||||
| } | } | ||||
| void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { | |||||
| void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { | |||||
| if (func_graph == nullptr || graph_proto == nullptr) { | if (func_graph == nullptr || graph_proto == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> parameters = func_graph->parameters(); | std::vector<AnfNodePtr> parameters = func_graph->parameters(); | ||||
| for (auto& param : parameters) { | |||||
| irpb::ParameterProto* param_proto = graph_proto->add_parameters(); | |||||
| for (auto ¶m : parameters) { | |||||
| irpb::ParameterProto *param_proto = graph_proto->add_parameters(); | |||||
| param_proto->set_name(param->ToString()); | param_proto->set_name(param->ToString()); | ||||
| SetNodeOutputType(param, param_proto->mutable_type()); | SetNodeOutputType(param, param_proto->mutable_type()); | ||||
| @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr) { | |||||
| void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr) { | |||||
| if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { | if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| // topo sort nodes | // topo sort nodes | ||||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | ||||
| std::map<AnfNodePtr, size_t> apply_map; | std::map<AnfNodePtr, size_t> apply_map; | ||||
| for (const AnfNodePtr& node : nodes) { | |||||
| for (const AnfNodePtr &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* apply_map_ptr, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||||
| void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *apply_map_ptr, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||||
| if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || | if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || | ||||
| graph_proto == nullptr) { | graph_proto == nullptr) { | ||||
| return; | return; | ||||
| @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||||
| auto apply_idx = apply_map_ptr->size() + 1; | auto apply_idx = apply_map_ptr->size() + 1; | ||||
| (*apply_map_ptr)[node] = apply_idx; | (*apply_map_ptr)[node] = apply_idx; | ||||
| auto& inputs = node->inputs(); | |||||
| auto &inputs = node->inputs(); | |||||
| if (inputs.size() < 1) { | if (inputs.size() < 1) { | ||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | ||||
| } | } | ||||
| AnfNodePtr op = inputs[0]; | AnfNodePtr op = inputs[0]; | ||||
| irpb::NodeProto* node_proto = graph_proto->add_node(); | |||||
| irpb::NodeProto *node_proto = graph_proto->add_node(); | |||||
| // CNode/ConstGraph/Const/Parameter | // CNode/ConstGraph/Const/Parameter | ||||
| if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) { | if (op->isa<CNode>() || IsValueNode<FuncGraph>(op) || op->isa<Parameter>()) { | ||||
| @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||||
| // process OP inputs | // process OP inputs | ||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| irpb::InputProto* input_proto = node_proto->add_input(); | |||||
| irpb::InputProto *input_proto = node_proto->add_input(); | |||||
| input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); | input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); | ||||
| std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); | std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); | ||||
| input_proto->set_name(id); | input_proto->set_name(id); | ||||
| @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& | |||||
| } | } | ||||
| } | } | ||||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, | |||||
| const std::map<AnfNodePtr, size_t>& apply_map, | |||||
| std::map<AnfNodePtr, size_t>* const_map_ptr, irpb::GraphProto* graph_proto) { | |||||
| void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, | |||||
| const std::map<AnfNodePtr, size_t> &apply_map, | |||||
| std::map<AnfNodePtr, size_t> *const_map_ptr, irpb::GraphProto *graph_proto) { | |||||
| if (ret_node == nullptr || !ret_node->isa<CNode>()) { | if (ret_node == nullptr || !ret_node->isa<CNode>()) { | ||||
| MS_LOG(EXCEPTION) << "Graph return node is illegal"; | MS_LOG(EXCEPTION) << "Graph return node is illegal"; | ||||
| } | } | ||||
| @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||||
| if (graph_proto == nullptr) { | if (graph_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | ||||
| } | } | ||||
| irpb::OutputProto* output_proto = graph_proto->add_outputs(); | |||||
| irpb::OutputProto *output_proto = graph_proto->add_outputs(); | |||||
| if (output_proto == nullptr) { | if (output_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "output_proto is nullptr"; | MS_LOG(EXCEPTION) << "output_proto is nullptr"; | ||||
| } | } | ||||
| @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const | |||||
| SetNodeOutputType(arg, output_proto->mutable_type()); | SetNodeOutputType(arg, output_proto->mutable_type()); | ||||
| } | } | ||||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t>& x, const std::pair<AnfNodePtr, size_t>& y) { | |||||
| static bool CompareValue(const std::pair<AnfNodePtr, size_t> &x, const std::pair<AnfNodePtr, size_t> &y) { | |||||
| return x.second < y.second; | return x.second < y.second; | ||||
| } | } | ||||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_map, irpb::GraphProto* graph_proto) { | |||||
| void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t> &const_map, irpb::GraphProto *graph_proto) { | |||||
| std::vector<std::pair<AnfNodePtr, size_t>> nodes; | std::vector<std::pair<AnfNodePtr, size_t>> nodes; | ||||
| (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), | (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), | ||||
| [](const std::pair<AnfNodePtr, size_t>& item) { return item; }); | |||||
| [](const std::pair<AnfNodePtr, size_t> &item) { return item; }); | |||||
| sort(nodes.begin(), nodes.end(), CompareValue); | sort(nodes.begin(), nodes.end(), CompareValue); | ||||
| for (auto& item : nodes) { | |||||
| for (auto &item : nodes) { | |||||
| if (graph_proto == nullptr) { | if (graph_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | MS_LOG(EXCEPTION) << "graph_proto is nullptr"; | ||||
| } | } | ||||
| irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); | |||||
| irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); | |||||
| MS_EXCEPTION_IF_NULL(named_value); | MS_EXCEPTION_IF_NULL(named_value); | ||||
| named_value->set_key(GetConstNodeId(item.second)); | named_value->set_key(GetConstNodeId(item.second)); | ||||
| SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); | SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); | ||||
| @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map<AnfNodePtr, size_t>& const_m | |||||
| void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } | void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } | ||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||||
| ProtoExporter exporter; | ProtoExporter exporter; | ||||
| return exporter.GetFuncGraphProtoString(func_graph); | return exporter.GetFuncGraphProtoString(func_graph); | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ Dump::Dump() | |||||
| dump_iter_(0), | dump_iter_(0), | ||||
| cur_iter_(0) {} | cur_iter_(0) {} | ||||
| bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||||
| bool Dump::IsKernelNeedDump(const std::string &kernel_name) { | |||||
| if (dump_mode_ == 0) { | if (dump_mode_ == 0) { | ||||
| // Dump All Kernels mode | // Dump All Kernels mode | ||||
| return true; | return true; | ||||
| @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||||
| bool Dump::ParseDumpConfig(const std::string &dump_config_file) { | |||||
| std::ifstream jsonFile(dump_config_file); | std::ifstream jsonFile(dump_config_file); | ||||
| if (!jsonFile.is_open()) { | if (!jsonFile.is_open()) { | ||||
| MS_LOG(ERROR) << dump_config_file << " open failed."; | MS_LOG(ERROR) << dump_config_file << " open failed."; | ||||
| @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||||
| bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { | |||||
| if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || | if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || | ||||
| dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || | dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || | ||||
| dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || | dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || | ||||
| @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||||
| bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { | |||||
| auto trans_flag = dumpSettings.at("trans_flag"); | auto trans_flag = dumpSettings.at("trans_flag"); | ||||
| auto enable = dumpSettings.at("enable"); | auto enable = dumpSettings.at("enable"); | ||||
| auto mode = dumpSettings.at("mode"); | auto mode = dumpSettings.at("mode"); | ||||
| @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { | |||||
| dump_path_ = path; | dump_path_ = path; | ||||
| dump_net_name_ = net_name; | dump_net_name_ = net_name; | ||||
| dump_iter_ = iteration; | dump_iter_ = iteration; | ||||
| for (const auto& kernel : kernels) { | |||||
| for (const auto &kernel : kernels) { | |||||
| dump_kernels_.push_back(kernel); | dump_kernels_.push_back(kernel); | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool Dump::SetDumpConfFromJsonFile() { | bool Dump::SetDumpConfFromJsonFile() { | ||||
| const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||||
| const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); | |||||
| if (config_path_str != nullptr) { | if (config_path_str != nullptr) { | ||||
| MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; | MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; | ||||
| } else { | } else { | ||||
| @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { | |||||
| return ParseDumpConfig(dump_config_file); | return ParseDumpConfig(dump_config_file); | ||||
| } | } | ||||
| bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { | |||||
| bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { | |||||
| if (filename.empty() || data == nullptr || len == 0) { | if (filename.empty() || data == nullptr || len == 0) { | ||||
| MS_LOG(ERROR) << "Incorrect parameter."; | MS_LOG(ERROR) << "Incorrect parameter."; | ||||
| return false; | return false; | ||||
| @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) | |||||
| MS_LOG(ERROR) << "Open file " << realpath << " fail."; | MS_LOG(ERROR) << "Open file " << realpath << " fail."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| (void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len)); | |||||
| (void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len)); | |||||
| fd.close(); | fd.close(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||||
| bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { | |||||
| MS_EXCEPTION_IF_NULL(outpath); | MS_EXCEPTION_IF_NULL(outpath); | ||||
| auto path_split_pos = inpath.find_last_of('/'); | auto path_split_pos = inpath.find_last_of('/'); | ||||
| if (path_split_pos == std::string::npos) { | if (path_split_pos == std::string::npos) { | ||||
| @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool Dump::CreateNotExistDirs(const std::string& path) { | |||||
| bool Dump::CreateNotExistDirs(const std::string &path) { | |||||
| std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); | std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); | ||||
| MS_EXCEPTION_IF_NULL(fs); | MS_EXCEPTION_IF_NULL(fs); | ||||
| char temp_path[PATH_MAX] = {0}; | char temp_path[PATH_MAX] = {0}; | ||||
| @@ -43,11 +43,11 @@ class Dump { | |||||
| uint32_t cur_iter() const { return cur_iter_; } | uint32_t cur_iter() const { return cur_iter_; } | ||||
| bool IsKernelNeedDump(const std::string& kernel_name); | |||||
| bool IsKernelNeedDump(const std::string &kernel_name); | |||||
| bool SetDumpConfFromJsonFile(); | bool SetDumpConfFromJsonFile(); | ||||
| static bool DumpToFile(const std::string& filename, const void* data, size_t len); | |||||
| static bool DumpToFile(const std::string &filename, const void *data, size_t len); | |||||
| protected: | protected: | ||||
| bool dump_enable_; | bool dump_enable_; | ||||
| @@ -59,14 +59,14 @@ class Dump { | |||||
| uint32_t cur_iter_; | uint32_t cur_iter_; | ||||
| std::vector<std::string> dump_kernels_; | std::vector<std::string> dump_kernels_; | ||||
| static bool GetRealPath(const std::string& inpath, std::string* outpath); | |||||
| static bool GetRealPath(const std::string &inpath, std::string *outpath); | |||||
| static bool CreateNotExistDirs(const std::string& path); | |||||
| static bool CreateNotExistDirs(const std::string &path); | |||||
| private: | private: | ||||
| bool ParseDumpConfig(const std::string& dump_config_file); | |||||
| bool IsConfigExist(const nlohmann::json& dumpSettings); | |||||
| bool IsConfigValid(const nlohmann::json& dumpSettings); | |||||
| bool ParseDumpConfig(const std::string &dump_config_file); | |||||
| bool IsConfigExist(const nlohmann::json &dumpSettings); | |||||
| bool IsConfigValid(const nlohmann::json &dumpSettings); | |||||
| }; | }; | ||||
| using DumpConfPtr = std::shared_ptr<Dump>; | using DumpConfPtr = std::shared_ptr<Dump>; | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "pipeline/parse/python_adapter.h" | #include "pipeline/parse/python_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { | |||||
| std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { | |||||
| std::string temp_line = line; | std::string temp_line = line; | ||||
| if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && | if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && | ||||
| tip != kSourceLineTipDiscard) { | tip != kSourceLineTipDiscard) { | ||||
| @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { | |||||
| name_ = ""; | name_ = ""; | ||||
| } | } | ||||
| DebugInfo::DebugInfo(const std::string& name) { | |||||
| DebugInfo::DebugInfo(const std::string &name) { | |||||
| InitValueFromContext(); | InitValueFromContext(); | ||||
| unique_id_ = gen_unique_id(); | unique_id_ = gen_unique_id(); | ||||
| debug_id_ = -1; | debug_id_ = -1; | ||||
| name_ = name; | name_ = name; | ||||
| } | } | ||||
| DebugInfo::DebugInfo(const LocationPtr& loc) { | |||||
| DebugInfo::DebugInfo(const LocationPtr &loc) { | |||||
| InitValueFromContext(); | InitValueFromContext(); | ||||
| unique_id_ = gen_unique_id(); | unique_id_ = gen_unique_id(); | ||||
| debug_id_ = -1; | debug_id_ = -1; | ||||
| @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { | |||||
| } | } | ||||
| int64_t DebugInfo::unique_id_through_copy() const { | int64_t DebugInfo::unique_id_through_copy() const { | ||||
| TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info(); | |||||
| TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info(); | |||||
| if (trace_info != nullptr) { | if (trace_info != nullptr) { | ||||
| if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { | if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { | ||||
| return trace_info->debug_info()->unique_id_through_copy(); | return trace_info->debug_info()->unique_id_through_copy(); | ||||
| @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { | |||||
| } | } | ||||
| return DebugInfo::location(); | return DebugInfo::location(); | ||||
| } | } | ||||
| void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } | |||||
| void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } | |||||
| TraceContextPtr TraceManager::CurrentContextInfo() { | TraceContextPtr TraceManager::CurrentContextInfo() { | ||||
| if (!TraceManager::trace_context_stack_.empty()) { | if (!TraceManager::trace_context_stack_.empty()) { | ||||
| @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { | |||||
| void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { | |||||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | TraceContextPtr context = std::make_shared<TraceContext>(location); | ||||
| context->set_func_name(func_name); | context->set_func_name(func_name); | ||||
| TraceManager::trace_context_stack_.push(context); | TraceManager::trace_context_stack_.push(context); | ||||
| } | } | ||||
| void TraceManager::DebugTrace(const LocationPtr& location) { | |||||
| void TraceManager::DebugTrace(const LocationPtr &location) { | |||||
| TraceContextPtr context = std::make_shared<TraceContext>(location); | TraceContextPtr context = std::make_shared<TraceContext>(location); | ||||
| TraceManager::trace_context_stack_.push(context); | TraceManager::trace_context_stack_.push(context); | ||||
| } | } | ||||
| void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||||
| void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { | |||||
| if (trace_info == nullptr) { | if (trace_info == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | ||||
| } | } | ||||
| @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { | |||||
| TraceManager::trace_context_stack_.push(context); | TraceManager::trace_context_stack_.push(context); | ||||
| } | } | ||||
| void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { | |||||
| void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { | |||||
| if (trace_info == nullptr) { | if (trace_info == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; | ||||
| } | } | ||||
| @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou | |||||
| // Location class record the location in source code. | // Location class record the location in source code. | ||||
| class Location { | class Location { | ||||
| public: | public: | ||||
| Location(const std::string& file_name, int line, int column, int line_end, int column_end) | |||||
| Location(const std::string &file_name, int line, int column, int line_end, int column_end) | |||||
| : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} | : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} | ||||
| Location(const Location& loc) | |||||
| Location(const Location &loc) | |||||
| : file_name_(loc.file_name_), | : file_name_(loc.file_name_), | ||||
| line_(loc.line_), | line_(loc.line_), | ||||
| column_(loc.column_), | column_(loc.column_), | ||||
| @@ -77,21 +77,21 @@ class TraceManager { | |||||
| TraceManager() = default; | TraceManager() = default; | ||||
| ~TraceManager() = default; | ~TraceManager() = default; | ||||
| static TraceContextPtr CurrentContextInfo(); | static TraceContextPtr CurrentContextInfo(); | ||||
| static void DebugTrace(const std::string& func_name, const LocationPtr& location); | |||||
| static void DebugTrace(const LocationPtr& location); | |||||
| static void DebugTrace(const TraceInfoPtr& trace_info); | |||||
| static void DebugTrace(const std::string &func_name, const LocationPtr &location); | |||||
| static void DebugTrace(const LocationPtr &location); | |||||
| static void DebugTrace(const TraceInfoPtr &trace_info); | |||||
| // debug trace with a cloned trace info with debug_info | // debug trace with a cloned trace info with debug_info | ||||
| static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); | |||||
| static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); | |||||
| static void EndTrace(); | static void EndTrace(); | ||||
| static std::stack<TraceContextPtr> trace_context_stack_; | static std::stack<TraceContextPtr> trace_context_stack_; | ||||
| }; | }; | ||||
| class TraceGuard { | class TraceGuard { | ||||
| public: | public: | ||||
| explicit TraceGuard(const std::string func_name, const LocationPtr& location) { | |||||
| explicit TraceGuard(const std::string func_name, const LocationPtr &location) { | |||||
| TraceManager::DebugTrace(func_name, location); | TraceManager::DebugTrace(func_name, location); | ||||
| } | } | ||||
| explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } | |||||
| explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } | |||||
| ~TraceGuard() { TraceManager::EndTrace(); } | ~TraceGuard() { TraceManager::EndTrace(); } | ||||
| }; | }; | ||||
| @@ -106,23 +106,23 @@ class TraceContext { | |||||
| public: | public: | ||||
| ~TraceContext() = default; | ~TraceContext() = default; | ||||
| explicit TraceContext(const LocationPtr& loc) { | |||||
| explicit TraceContext(const LocationPtr &loc) { | |||||
| ProcessAttributeFromContext(); | ProcessAttributeFromContext(); | ||||
| location_ = loc; | location_ = loc; | ||||
| } | } | ||||
| explicit TraceContext(const std::string& func_name) { | |||||
| explicit TraceContext(const std::string &func_name) { | |||||
| ProcessAttributeFromContext(); | ProcessAttributeFromContext(); | ||||
| func_name_ = func_name; | func_name_ = func_name; | ||||
| } | } | ||||
| explicit TraceContext(const TraceInfoPtr& trace_info) { | |||||
| explicit TraceContext(const TraceInfoPtr &trace_info) { | |||||
| ProcessAttributeFromContext(); | ProcessAttributeFromContext(); | ||||
| trace_info_ = trace_info; | trace_info_ = trace_info; | ||||
| } | } | ||||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||||
| LocationPtr location() { return location_; } | LocationPtr location() { return location_; } | ||||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||||
| TraceInfoPtr trace_info() { return trace_info_; } | TraceInfoPtr trace_info() { return trace_info_; } | ||||
| void set_func_name(const std::string& func_name) { func_name_ = func_name; } | |||||
| void set_func_name(const std::string &func_name) { func_name_ = func_name; } | |||||
| std::string func_name() { return func_name_; } | std::string func_name() { return func_name_; } | ||||
| }; | }; | ||||
| @@ -130,9 +130,9 @@ class DebugInfo : public Base { | |||||
| public: | public: | ||||
| DebugInfo(); | DebugInfo(); | ||||
| explicit DebugInfo(const std::string& name); | |||||
| explicit DebugInfo(const std::string &name); | |||||
| explicit DebugInfo(const LocationPtr& loc); | |||||
| explicit DebugInfo(const LocationPtr &loc); | |||||
| virtual ~DebugInfo() = default; | virtual ~DebugInfo() = default; | ||||
| MS_DECLARE_PARENT(DebugInfo, Base); | MS_DECLARE_PARENT(DebugInfo, Base); | ||||
| @@ -141,12 +141,12 @@ class DebugInfo : public Base { | |||||
| int64_t unique_id_through_copy() const; | int64_t unique_id_through_copy() const; | ||||
| std::string get_id() { return std::to_string(debug_id()); } | std::string get_id() { return std::to_string(debug_id()); } | ||||
| void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } | |||||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | |||||
| TraceInfoPtr trace_info() { return trace_info_; } | TraceInfoPtr trace_info() { return trace_info_; } | ||||
| void set_location(const LocationPtr& loc) { location_ = loc; } | |||||
| void set_location(const LocationPtr &loc) { location_ = loc; } | |||||
| virtual LocationPtr location() { return location_; } | virtual LocationPtr location() { return location_; } | ||||
| std::string name() { return name_; } | std::string name() { return name_; } | ||||
| void set_name(const std::string& name) { name_ = name; } | |||||
| void set_name(const std::string &name) { name_ = name; } | |||||
| virtual std::string debug_name(); | virtual std::string debug_name(); | ||||
| virtual std::string get_python_func_belonged() { return ""; } | virtual std::string get_python_func_belonged() { return ""; } | ||||
| @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { | |||||
| py_func_belonged_ = context_info->func_name(); | py_func_belonged_ = context_info->func_name(); | ||||
| } | } | ||||
| } | } | ||||
| explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { | |||||
| explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { | |||||
| if (TraceManager::CurrentContextInfo() != nullptr) { | if (TraceManager::CurrentContextInfo() != nullptr) { | ||||
| auto context_info = TraceManager::CurrentContextInfo(); | auto context_info = TraceManager::CurrentContextInfo(); | ||||
| py_func_belonged_ = context_info->func_name(); | py_func_belonged_ = context_info->func_name(); | ||||
| @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { | |||||
| ~NodeDebugInfo() override = default; | ~NodeDebugInfo() override = default; | ||||
| std::string debug_name() override; | std::string debug_name() override; | ||||
| void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); } | |||||
| void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); } | |||||
| std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } | std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } | ||||
| void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } | |||||
| void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } | |||||
| std::string get_python_func_belonged() override { return py_func_belonged_; } | std::string get_python_func_belonged() override { return py_func_belonged_; } | ||||
| AnfNodeWeakPtr node_; | AnfNodeWeakPtr node_; | ||||
| std::string py_func_belonged_; | std::string py_func_belonged_; | ||||
| @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { | |||||
| } | } | ||||
| } | } | ||||
| explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { | |||||
| explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { | |||||
| if (TraceManager::CurrentContextInfo() != nullptr) { | if (TraceManager::CurrentContextInfo() != nullptr) { | ||||
| auto context_info = TraceManager::CurrentContextInfo(); | auto context_info = TraceManager::CurrentContextInfo(); | ||||
| py_func_name_ = context_info->func_name(); | py_func_name_ = context_info->func_name(); | ||||
| @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { | |||||
| std::string debug_name() override; | std::string debug_name() override; | ||||
| LocationPtr location() override; | LocationPtr location() override; | ||||
| LocationPtr deco_location() { return deco_loc_; } | LocationPtr deco_location() { return deco_loc_; } | ||||
| void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||||
| void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } | |||||
| FuncGraphPtr get_graph() const { return func_graph_.lock(); } | FuncGraphPtr get_graph() const { return func_graph_.lock(); } | ||||
| void set_full_name(const std::string& name) { full_name_ = name; } | |||||
| void set_full_name(const std::string &name) { full_name_ = name; } | |||||
| std::string get_full_name() { return full_name_; } | std::string get_full_name() { return full_name_; } | ||||
| void set_deco_location(const LocationPtr& deco_list_loc); | |||||
| void set_deco_location(const LocationPtr &deco_list_loc); | |||||
| std::string get_python_func_belonged() override { return py_func_name_; } | std::string get_python_func_belonged() override { return py_func_name_; } | ||||
| FuncGraphWeakPtr func_graph_; | FuncGraphWeakPtr func_graph_; | ||||
| LocationPtr deco_loc_; | LocationPtr deco_loc_; | ||||
| @@ -31,7 +31,7 @@ struct NameWithTrace { | |||||
| std::string name; | std::string name; | ||||
| std::vector<std::string> trace_labels; | std::vector<std::string> trace_labels; | ||||
| }; | }; | ||||
| static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { | |||||
| static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { | |||||
| switch (trace_label) { | switch (trace_label) { | ||||
| case TraceLabelType::kShortSymbol: | case TraceLabelType::kShortSymbol: | ||||
| return trace_info->symbol(); | return trace_info->symbol(); | ||||
| @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t | |||||
| } | } | ||||
| } | } | ||||
| NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||||
| NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||||
| NameWithTrace trace_name; | NameWithTrace trace_name; | ||||
| // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node | // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node | ||||
| auto temp_info = debug_info; | auto temp_info = debug_info; | ||||
| @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe | |||||
| return trace_name; | return trace_name; | ||||
| } | } | ||||
| std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) { | |||||
| std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) { | |||||
| std::string tags = ""; | std::string tags = ""; | ||||
| for (auto& itr : trace_labels) { | |||||
| for (auto &itr : trace_labels) { | |||||
| std::string symbol = itr; | std::string symbol = itr; | ||||
| tags = tags + symbol; | tags = tags + symbol; | ||||
| } | } | ||||
| @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st | |||||
| } | } | ||||
| // get the label name of the node debug info | // get the label name of the node debug info | ||||
| std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||||
| std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||||
| NameWithTrace trace_name = RootName(debug_info, trace_label); | NameWithTrace trace_name = RootName(debug_info, trace_label); | ||||
| return CombineTraceTypes(trace_name.name, trace_name.trace_labels); | return CombineTraceTypes(trace_name.name, trace_name.trace_labels); | ||||
| } | } | ||||
| std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||||
| std::string CombineUniqueID(const DebugInfoPtr &debug_info) { | |||||
| auto temp_info = debug_info; | auto temp_info = debug_info; | ||||
| std::string label = ""; | std::string label = ""; | ||||
| while (temp_info != nullptr) { | while (temp_info != nullptr) { | ||||
| @@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) { | |||||
| } | } | ||||
| // get trace with unique id chain | // get trace with unique id chain | ||||
| std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); } | |||||
| std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); } | |||||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { | |||||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { | |||||
| if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { | if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { | ||||
| return LabelStringUnique(debug_info); | return LabelStringUnique(debug_info); | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ namespace label_manage { | |||||
| enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; | enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; | ||||
| TraceLabelType GetGlobalTraceLabelType(); | TraceLabelType GetGlobalTraceLabelType(); | ||||
| void SetGlobalTraceLabelType(TraceLabelType label_type); | void SetGlobalTraceLabelType(TraceLabelType label_type); | ||||
| std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||||
| std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); | |||||
| } // namespace label_manage | } // namespace label_manage | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,7 +37,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support debug trace infomation | // namespace to support debug trace infomation | ||||
| namespace trace { | namespace trace { | ||||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) { | |||||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) { | |||||
| if (abs == nullptr) { | if (abs == nullptr) { | ||||
| return "Null Abstract"; | return "Null Abstract"; | ||||
| } | } | ||||
| @@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { | |||||
| return debug_with_loc_vec; | return debug_with_loc_vec; | ||||
| } | } | ||||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { | |||||
| auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); | auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); | ||||
| if (debug_with_loc_vec.size() > 0) { | if (debug_with_loc_vec.size() > 0) { | ||||
| return debug_with_loc_vec[0]; | return debug_with_loc_vec[0]; | ||||
| @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { | |||||
| } | } | ||||
| } | } | ||||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||||
| if (info == nullptr) { | if (info == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||||
| // a trace info identifies a node transform, so we can trace the node transform through | // a trace info identifies a node transform, so we can trace the node transform through | ||||
| // a link of trace info and debug info | // a link of trace info and debug info | ||||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) { | |||||
| std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) { | |||||
| if (info_vec.size() < 1) { | if (info_vec.size() < 1) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL | |||||
| return traced_info; | return traced_info; | ||||
| } | } | ||||
| std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||||
| std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { | |||||
| if (info == nullptr) { | if (info == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { | |||||
| return ""; | return ""; | ||||
| } | } | ||||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { | |||||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (info == nullptr) { | if (info == nullptr) { | ||||
| return ""; | return ""; | ||||
| @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { | |||||
| std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << "graph:" << graph->ToString() << " with args["; | oss << "graph:" << graph->ToString() << " with args["; | ||||
| auto params = graph->parameters(); | auto params = graph->parameters(); | ||||
| @@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| void DumpInferStack(std::ostringstream& oss) { | |||||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||||
| void DumpInferStack(std::ostringstream &oss) { | |||||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||||
| if (infer_stack.empty()) { | if (infer_stack.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||||
| } | } | ||||
| std::reverse(infer_vec.begin(), infer_vec.end()); | std::reverse(infer_vec.begin(), infer_vec.end()); | ||||
| int index = 0; | int index = 0; | ||||
| for (auto& item : infer_vec) { | |||||
| for (auto &item : infer_vec) { | |||||
| auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); | auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); | ||||
| if (graph_infer == nullptr) { | if (graph_infer == nullptr) { | ||||
| MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; | MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; | ||||
| @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { | |||||
| } | } | ||||
| void TraceGraphInfer() { | void TraceGraphInfer() { | ||||
| auto& infer_stack = GetCurrenGraphInferStack(); | |||||
| auto &infer_stack = GetCurrenGraphInferStack(); | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| if (infer_stack.empty()) { | if (infer_stack.empty()) { | ||||
| return; | return; | ||||
| @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { | |||||
| AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} | ||||
| ~AnalyzedFuncGraphExporter() override = default; | ~AnalyzedFuncGraphExporter() override = default; | ||||
| void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); | |||||
| void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs); | |||||
| private: | private: | ||||
| std::string GetNodeType(const AnfNodePtr& nd) override; | |||||
| std::string GetNodeType(const AnfNodePtr &nd) override; | |||||
| }; | }; | ||||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { | ||||
| std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; | ||||
| auto& list = GetCNodeDebugStack(); | |||||
| auto &list = GetCNodeDebugStack(); | |||||
| for (size_t i = 0; i < list.size(); ++i) { | for (size_t i = 0; i < list.size(); ++i) { | ||||
| auto node_cfg = list[i]; | auto node_cfg = list[i]; | ||||
| auto fg = node_cfg->context()->func_graph(); | auto fg = node_cfg->context()->func_graph(); | ||||
| @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { | |||||
| exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); | ||||
| } | } | ||||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||||
| std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { | |||||
| if (node_cfg_ == nullptr) { | if (node_cfg_ == nullptr) { | ||||
| return AnfExporter::GetNodeType(node); | return AnfExporter::GetNodeType(node); | ||||
| } | } | ||||
| @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||||
| const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { | |||||
| void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, | |||||
| const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) { | |||||
| if (node_cfgs.empty()) { | if (node_cfgs.empty()) { | ||||
| MS_LOG(DEBUG) << "Node configs is empty"; | MS_LOG(DEBUG) << "Node configs is empty"; | ||||
| return; | return; | ||||
| @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||||
| auto tagged_func_graphs = CalcTaggedFuncGraphs(); | auto tagged_func_graphs = CalcTaggedFuncGraphs(); | ||||
| // first output graph on the analysis stack | // first output graph on the analysis stack | ||||
| for (const auto& node_cfg : node_cfgs) { | |||||
| for (const auto &node_cfg : node_cfgs) { | |||||
| auto fg = node_cfg->context()->func_graph(); | auto fg = node_cfg->context()->func_graph(); | ||||
| // the graph is already output, skip it | // the graph is already output, skip it | ||||
| if (exported.find(fg) != exported.end()) { | if (exported.find(fg) != exported.end()) { | ||||
| @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, | |||||
| ofs.close(); | ofs.close(); | ||||
| } | } | ||||
| void GetInferStackInfo(std::ostringstream& oss) { | |||||
| void GetInferStackInfo(std::ostringstream &oss) { | |||||
| MS_LOG(INFO) << "Get graph analysis information begin"; | MS_LOG(INFO) << "Get graph analysis information begin"; | ||||
| auto stack = GetCNodeDebugStack(); | auto stack = GetCNodeDebugStack(); | ||||
| if (stack.empty()) { | if (stack.empty()) { | ||||
| @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { | |||||
| static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; | ||||
| // trace the cnode infer debug info | // trace the cnode infer debug info | ||||
| static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; | ||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { | |||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { | |||||
| if (eval == nullptr) { | if (eval == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | ||||
| } | } | ||||
| @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An | |||||
| } | } | ||||
| } | } | ||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { | |||||
| if (eval == nullptr) { | if (eval == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; | ||||
| } | } | ||||
| @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { | |||||
| } | } | ||||
| } | } | ||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } | |||||
| void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } | void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } | ||||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; } | |||||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } | |||||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() { | |||||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() { | |||||
| return graph_infer_stack; | return graph_infer_stack; | ||||
| } | } | ||||
| void ClearTraceStack() { | void ClearTraceStack() { | ||||
| @@ -31,19 +31,19 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace trace { | namespace trace { | ||||
| std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); | |||||
| std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, | |||||
| std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); | |||||
| std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, | |||||
| SourceLineTip tip = kSourceLineTipNextLine); | SourceLineTip tip = kSourceLineTipNextLine); | ||||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); | |||||
| DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); | |||||
| void TraceGraphInfer(); | void TraceGraphInfer(); | ||||
| void GetInferStackInfo(std::ostringstream& oss); | |||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); | |||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); | |||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); | |||||
| void GetInferStackInfo(std::ostringstream &oss); | |||||
| void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); | |||||
| void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); | |||||
| void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); | |||||
| void TraceInferCNodeLeave(); | void TraceInferCNodeLeave(); | ||||
| std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack(); | |||||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack(); | |||||
| std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); | |||||
| std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); | |||||
| std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); | |||||
| std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); | |||||
| void ClearTraceStack(); | void ClearTraceStack(); | ||||
| } // namespace trace | } // namespace trace | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "pipeline/parse/python_adapter.h" | #include "pipeline/parse/python_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { | |||||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { | |||||
| if (info == nullptr) { | if (info == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>; | |||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| class TraceInfo : public Base { | class TraceInfo : public Base { | ||||
| public: | public: | ||||
| TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { | |||||
| TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { | |||||
| symbol_ = symbol; | symbol_ = symbol; | ||||
| full_name_ = full_name; | full_name_ = full_name; | ||||
| name_ = full_name_; | name_ = full_name_; | ||||
| debug_info_ = info; | debug_info_ = info; | ||||
| } | } | ||||
| TraceInfo(const TraceInfo& info) | |||||
| TraceInfo(const TraceInfo &info) | |||||
| : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} | : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} | ||||
| virtual ~TraceInfo() = default; | virtual ~TraceInfo() = default; | ||||
| MS_DECLARE_PARENT(TraceInfo, Base); | MS_DECLARE_PARENT(TraceInfo, Base); | ||||
| @@ -55,8 +55,8 @@ class TraceInfo : public Base { | |||||
| virtual std::string full_name() { return full_name_; } | virtual std::string full_name() { return full_name_; } | ||||
| virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); } | virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); } | ||||
| virtual std::string action_name() { return ""; } | virtual std::string action_name() { return ""; } | ||||
| virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); | |||||
| void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } | |||||
| virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); | |||||
| void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } | |||||
| DebugInfoPtr debug_info() { return debug_info_; } | DebugInfoPtr debug_info() { return debug_info_; } | ||||
| DebugInfoPtr DebugInfoHasLoc(); | DebugInfoPtr DebugInfoHasLoc(); | ||||
| std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo(); | std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo(); | ||||
| @@ -70,7 +70,7 @@ class TraceInfo : public Base { | |||||
| class TracePhi : public TraceInfo { | class TracePhi : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} | |||||
| explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} | |||||
| MS_DECLARE_PARENT(TracePhi, TraceInfo); | MS_DECLARE_PARENT(TracePhi, TraceInfo); | ||||
| ~TracePhi() override = default; | ~TracePhi() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); } | TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); } | ||||
| @@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { | |||||
| class TraceIfStmtTrueBranch : public TraceInfo { | class TraceIfStmtTrueBranch : public TraceInfo { | ||||
| public: | public: | ||||
| TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; | |||||
| explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} | |||||
| TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; | |||||
| explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} | |||||
| MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); | MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); | ||||
| ~TraceIfStmtTrueBranch() override = default; | ~TraceIfStmtTrueBranch() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { | |||||
| class TraceIfStmtFalseBranch : public TraceInfo { | class TraceIfStmtFalseBranch : public TraceInfo { | ||||
| public: | public: | ||||
| TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; | |||||
| explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} | |||||
| TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; | |||||
| explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} | |||||
| MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); | MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); | ||||
| ~TraceIfStmtFalseBranch() override = default; | ~TraceIfStmtFalseBranch() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { | |||||
| class TraceIfStmtAfterBranch : public TraceInfo { | class TraceIfStmtAfterBranch : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} | |||||
| explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} | |||||
| MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); | MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); | ||||
| ~TraceIfStmtAfterBranch() override = default; | ~TraceIfStmtAfterBranch() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { | |||||
| class TraceIfExpTrueBranch : public TraceInfo { | class TraceIfExpTrueBranch : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} | |||||
| explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} | |||||
| MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); | MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); | ||||
| ~TraceIfExpTrueBranch() override = default; | ~TraceIfExpTrueBranch() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { | |||||
| class TraceIfExpFalseBranch : public TraceInfo { | class TraceIfExpFalseBranch : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} | |||||
| explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} | |||||
| MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); | MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); | ||||
| ~TraceIfExpFalseBranch() override = default; | ~TraceIfExpFalseBranch() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { | |||||
| class TraceCopy : public TraceInfo { | class TraceCopy : public TraceInfo { | ||||
| public: | public: | ||||
| TraceCopy() : TraceInfo(nullptr, "copy", "") {} | TraceCopy() : TraceInfo(nullptr, "copy", "") {} | ||||
| explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} | |||||
| explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} | |||||
| MS_DECLARE_PARENT(TraceCopy, TraceInfo); | MS_DECLARE_PARENT(TraceCopy, TraceInfo); | ||||
| ~TraceCopy() override = default; | ~TraceCopy() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); } | ||||
| @@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { | |||||
| class TraceIterator : public TraceInfo { | class TraceIterator : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} | |||||
| explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} | |||||
| MS_DECLARE_PARENT(TraceIterator, TraceInfo); | MS_DECLARE_PARENT(TraceIterator, TraceInfo); | ||||
| ~TraceIterator() override = default; | ~TraceIterator() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); } | ||||
| @@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { | |||||
| class TraceWhileHeader : public TraceInfo { | class TraceWhileHeader : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} | |||||
| explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} | |||||
| MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); | MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); | ||||
| ~TraceWhileHeader() override = default; | ~TraceWhileHeader() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); } | ||||
| @@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { | |||||
| class TraceWhileBody : public TraceInfo { | class TraceWhileBody : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} | |||||
| explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} | |||||
| MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); | MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); | ||||
| ~TraceWhileBody() override = default; | ~TraceWhileBody() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); } | ||||
| @@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { | |||||
| class TraceWhileAfter : public TraceInfo { | class TraceWhileAfter : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} | |||||
| explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} | |||||
| MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); | MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); | ||||
| ~TraceWhileAfter() override = default; | ~TraceWhileAfter() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); } | ||||
| @@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { | |||||
| class TraceForHeader : public TraceInfo { | class TraceForHeader : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} | |||||
| explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} | |||||
| MS_DECLARE_PARENT(TraceForHeader, TraceInfo); | MS_DECLARE_PARENT(TraceForHeader, TraceInfo); | ||||
| ~TraceForHeader() override = default; | ~TraceForHeader() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); } | ||||
| @@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { | |||||
| class TraceForBody : public TraceInfo { | class TraceForBody : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} | |||||
| explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} | |||||
| MS_DECLARE_PARENT(TraceForBody, TraceInfo); | MS_DECLARE_PARENT(TraceForBody, TraceInfo); | ||||
| ~TraceForBody() override = default; | ~TraceForBody() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); } | ||||
| @@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { | |||||
| class TraceForAfter : public TraceInfo { | class TraceForAfter : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} | |||||
| explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} | |||||
| MS_DECLARE_PARENT(TraceForAfter, TraceInfo); | MS_DECLARE_PARENT(TraceForAfter, TraceInfo); | ||||
| ~TraceForAfter() override = default; | ~TraceForAfter() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } | ||||
| @@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { | |||||
| class TraceEquiv : public TraceInfo { | class TraceEquiv : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} | |||||
| explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} | |||||
| MS_DECLARE_PARENT(TraceEquiv, TraceInfo); | MS_DECLARE_PARENT(TraceEquiv, TraceInfo); | ||||
| ~TraceEquiv() override = default; | ~TraceEquiv() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); } | ||||
| @@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { | |||||
| class TraceGradFpropApp : public TraceInfo { | class TraceGradFpropApp : public TraceInfo { | ||||
| public: | public: | ||||
| TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} | TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} | ||||
| explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} | |||||
| explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} | |||||
| MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); | MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); | ||||
| ~TraceGradFpropApp() override = default; | ~TraceGradFpropApp() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); } | ||||
| @@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { | |||||
| class TraceGradBpropApp : public TraceInfo { | class TraceGradBpropApp : public TraceInfo { | ||||
| public: | public: | ||||
| TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} | TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} | ||||
| explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} | |||||
| explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} | |||||
| MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); | MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); | ||||
| ~TraceGradBpropApp() override = default; | ~TraceGradBpropApp() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); } | ||||
| @@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { | |||||
| class TraceGradFprop : public TraceInfo { | class TraceGradFprop : public TraceInfo { | ||||
| public: | public: | ||||
| TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} | TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} | ||||
| explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} | |||||
| explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} | |||||
| MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); | MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); | ||||
| ~TraceGradFprop() override = default; | ~TraceGradFprop() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); } | ||||
| @@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { | |||||
| class TraceGradBprop : public TraceInfo { | class TraceGradBprop : public TraceInfo { | ||||
| public: | public: | ||||
| TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} | TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} | ||||
| explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} | |||||
| explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} | |||||
| MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); | MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); | ||||
| ~TraceGradBprop() override = default; | ~TraceGradBprop() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); } | ||||
| @@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { | |||||
| class TraceGradSens : public TraceInfo { | class TraceGradSens : public TraceInfo { | ||||
| public: | public: | ||||
| TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} | TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} | ||||
| explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} | |||||
| explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} | |||||
| MS_DECLARE_PARENT(TraceGradSens, TraceInfo); | MS_DECLARE_PARENT(TraceGradSens, TraceInfo); | ||||
| ~TraceGradSens() override = default; | ~TraceGradSens() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); } | ||||
| @@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { | |||||
| class TraceSpecialize : public TraceInfo { | class TraceSpecialize : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } | |||||
| explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } | |||||
| MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); | MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); | ||||
| std::string name() override { return full_name_ + counter_; } | std::string name() override { return full_name_ + counter_; } | ||||
| std::string symbol() override { return counter_ + "_"; } | std::string symbol() override { return counter_ + "_"; } | ||||
| @@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { | |||||
| class TraceGradOperation : public TraceInfo { | class TraceGradOperation : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} | |||||
| explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} | |||||
| MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); | MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); | ||||
| ~TraceGradOperation() override = default; | ~TraceGradOperation() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { | |||||
| class TraceForceBool : public TraceInfo { | class TraceForceBool : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} | |||||
| explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} | |||||
| MS_DECLARE_PARENT(TraceForceBool, TraceInfo); | MS_DECLARE_PARENT(TraceForceBool, TraceInfo); | ||||
| ~TraceForceBool() override = default; | ~TraceForceBool() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); } | ||||
| @@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { | |||||
| class TraceExpandJ : public TraceInfo { | class TraceExpandJ : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} | |||||
| explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} | |||||
| MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); | MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); | ||||
| ~TraceExpandJ() override = default; | ~TraceExpandJ() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); } | ||||
| @@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { | |||||
| class TraceGenMetaFuncGraph : public TraceInfo { | class TraceGenMetaFuncGraph : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} | |||||
| explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} | |||||
| MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); | MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); | ||||
| ~TraceGenMetaFuncGraph() override = default; | ~TraceGenMetaFuncGraph() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { | |||||
| class TraceEvaluatorGenGraph : public TraceInfo { | class TraceEvaluatorGenGraph : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} | |||||
| explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} | |||||
| MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); | MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); | ||||
| ~TraceEvaluatorGenGraph() override = default; | ~TraceEvaluatorGenGraph() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { | |||||
| class TraceResolve : public TraceInfo { | class TraceResolve : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} | |||||
| explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} | |||||
| MS_DECLARE_PARENT(TraceResolve, TraceInfo); | MS_DECLARE_PARENT(TraceResolve, TraceInfo); | ||||
| ~TraceResolve() override = default; | ~TraceResolve() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); } | ||||
| @@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { | |||||
| class TraceTransform : public TraceInfo { | class TraceTransform : public TraceInfo { | ||||
| public: | public: | ||||
| TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } | TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } | ||||
| explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { | |||||
| explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { | |||||
| transform_name_ = transform_name; | transform_name_ = transform_name; | ||||
| } | } | ||||
| @@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { | |||||
| class TraceGenerateVarArg : public TraceInfo { | class TraceGenerateVarArg : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} | |||||
| explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} | |||||
| MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); | MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); | ||||
| ~TraceGenerateVarArg() override = default; | ~TraceGenerateVarArg() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { | |||||
| class TraceGenerateKwArg : public TraceInfo { | class TraceGenerateKwArg : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} | |||||
| explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} | |||||
| MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); | MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); | ||||
| ~TraceGenerateKwArg() override = default; | ~TraceGenerateKwArg() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { | |||||
| class TraceTrasformK : public TraceInfo { | class TraceTrasformK : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} | |||||
| explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} | |||||
| MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); | MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); | ||||
| ~TraceTrasformK() override = default; | ~TraceTrasformK() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); } | ||||
| @@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { | |||||
| class TracePartialTransform : public TraceInfo { | class TracePartialTransform : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} | |||||
| explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} | |||||
| MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); | MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); | ||||
| ~TracePartialTransform() override = default; | ~TracePartialTransform() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { | |||||
| class TraceGetEnv : public TraceInfo { | class TraceGetEnv : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} | |||||
| explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} | |||||
| MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); | MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); | ||||
| ~TraceGetEnv() override = default; | ~TraceGetEnv() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); } | ||||
| @@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { | |||||
| class TraceDoSignature : public TraceInfo { | class TraceDoSignature : public TraceInfo { | ||||
| public: | public: | ||||
| explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} | |||||
| explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} | |||||
| MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); | MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); | ||||
| ~TraceDoSignature() override = default; | ~TraceDoSignature() override = default; | ||||
| TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); } | TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); } | ||||
| @@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { | |||||
| class TraceCombileLikeGraphs : public TraceInfo { | class TraceCombileLikeGraphs : public TraceInfo { | ||||
| public: | public: | ||||
| TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} | TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} | ||||
| explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} | |||||
| explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} | |||||
| MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); | MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); | ||||
| ~TraceCombileLikeGraphs() override = default; | ~TraceCombileLikeGraphs() override = default; | ||||
| TraceInfoPtr clone() override { | TraceInfoPtr clone() override { | ||||
| @@ -21,7 +21,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||||
| size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | |||||
| if (has_malloc_) { | if (has_malloc_) { | ||||
| MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; | MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||||
| return size; | return size; | ||||
| } | } | ||||
| bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { | |||||
| bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { | |||||
| MS_EXCEPTION_IF_NULL(addr); | MS_EXCEPTION_IF_NULL(addr); | ||||
| has_malloc_ = false; | has_malloc_ = false; | ||||
| free_mem_size_ = total_mem_size_; | free_mem_size_ = total_mem_size_; | ||||
| @@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { | |||||
| size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } | size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } | ||||
| void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { | |||||
| void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { | |||||
| MS_EXCEPTION_IF_NULL(device_mem_pool_base); | MS_EXCEPTION_IF_NULL(device_mem_pool_base); | ||||
| device_mem_pool_base_ = device_mem_pool_base; | device_mem_pool_base_ = device_mem_pool_base; | ||||
| } | } | ||||
| @@ -26,12 +26,12 @@ namespace ascend { | |||||
| class AscendMemoryPool : public DynamicMemPoolBestFit { | class AscendMemoryPool : public DynamicMemPoolBestFit { | ||||
| public: | public: | ||||
| ~AscendMemoryPool() override = default; | ~AscendMemoryPool() override = default; | ||||
| AscendMemoryPool(const AscendMemoryPool&) = delete; | |||||
| AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; | |||||
| AscendMemoryPool(const AscendMemoryPool &) = delete; | |||||
| AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; | |||||
| bool FreeDeviceMem(const DeviceMemPtr& addr) override; | |||||
| void set_device_mem_pool_base(uint8_t* device_mem_pool_base); | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | |||||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | |||||
| void set_device_mem_pool_base(uint8_t *device_mem_pool_base); | |||||
| void set_device_mem_pool_size(uint64_t device_mem_pool_size) { | void set_device_mem_pool_size(uint64_t device_mem_pool_size) { | ||||
| device_mem_pool_size_ = device_mem_pool_size; | device_mem_pool_size_ = device_mem_pool_size; | ||||
| free_mem_size_ = device_mem_pool_size_; | free_mem_size_ = device_mem_pool_size_; | ||||
| @@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||||
| size_t free_mem_size() override; | size_t free_mem_size() override; | ||||
| size_t total_mem_size() override; | size_t total_mem_size() override; | ||||
| static AscendMemoryPool& GetInstance() { | |||||
| static AscendMemoryPool &GetInstance() { | |||||
| static AscendMemoryPool instance; | static AscendMemoryPool instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { | |||||
| private: | private: | ||||
| AscendMemoryPool() = default; | AscendMemoryPool() = default; | ||||
| bool has_malloc_{false}; | bool has_malloc_{false}; | ||||
| uint8_t* device_mem_pool_base_{nullptr}; | |||||
| uint8_t *device_mem_pool_base_{nullptr}; | |||||
| uint64_t device_mem_pool_size_{0}; | uint64_t device_mem_pool_size_{0}; | ||||
| size_t free_mem_size_{0}; | size_t free_mem_size_{0}; | ||||
| size_t total_mem_size_{0}; | size_t total_mem_size_{0}; | ||||
| @@ -39,13 +39,13 @@ using std::vector; | |||||
| class AscendStreamAssign { | class AscendStreamAssign { | ||||
| public: | public: | ||||
| static AscendStreamAssign& GetInstance() { | |||||
| static AscendStreamAssign &GetInstance() { | |||||
| static AscendStreamAssign instance; // Guaranteed to be destroyed. | static AscendStreamAssign instance; // Guaranteed to be destroyed. | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| AscendStreamAssign(const AscendStreamAssign&) = delete; | |||||
| AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; | |||||
| AscendStreamAssign(const AscendStreamAssign &) = delete; | |||||
| AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; | |||||
| uint32_t GetTotalStreamNum() const; | uint32_t GetTotalStreamNum() const; | ||||
| // new stream policy | // new stream policy | ||||
| @@ -53,19 +53,19 @@ class AscendStreamAssign { | |||||
| uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } | uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } | ||||
| uint32_t total_event_num() const { return total_event_num_; } | uint32_t total_event_num() const { return total_event_num_; } | ||||
| void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void ResetNew(); | void ResetNew(); | ||||
| void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| bool IsIndependentNode(const CNodePtr& node_ptr); | |||||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; } | |||||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; } | |||||
| const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; } | |||||
| void GetWaitStreams(vector<uint32_t>* wait_active_stream_list); | |||||
| const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; } | |||||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | |||||
| void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| bool IsIndependentNode(const CNodePtr &node_ptr); | |||||
| const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; } | |||||
| const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; } | |||||
| const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; } | |||||
| void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | |||||
| const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; } | |||||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||||
| uint32_t stream_id); | uint32_t stream_id); | ||||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | |||||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||||
| uint32_t stream_id); | uint32_t stream_id); | ||||
| private: | private: | ||||
| @@ -73,30 +73,30 @@ class AscendStreamAssign { | |||||
| ~AscendStreamAssign() = default; | ~AscendStreamAssign() = default; | ||||
| vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, | vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, | ||||
| const CNodePtr& node); | |||||
| const CNodePtr &node); | |||||
| bool IsHcom(const CNodePtr& apply_kernel); | |||||
| bool IsHcom(const CNodePtr &apply_kernel); | |||||
| bool IsProcessed(uint32_t logic_id); | bool IsProcessed(uint32_t logic_id); | ||||
| void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids); | |||||
| void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, | |||||
| uint32_t* cur_stream_id); | |||||
| void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids); | |||||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, | |||||
| uint32_t *cur_stream_id); | |||||
| void RecordIdMap(uint32_t logic_id, uint32_t physic_id); | void RecordIdMap(uint32_t logic_id, uint32_t physic_id); | ||||
| void UpdateStreamActive(const CNodePtr& active_ptr); | |||||
| void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); | |||||
| void UpdateStreamActive(const CNodePtr &active_ptr); | |||||
| void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); | |||||
| bool IsTaskSink(); | bool IsTaskSink(); | ||||
| void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); | |||||
| void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | |||||
| uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); | |||||
| void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); | |||||
| void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | |||||
| uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); | |||||
| void SetCommonStreamNum(uint32_t cur_stream_id); | void SetCommonStreamNum(uint32_t cur_stream_id); | ||||
| void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| bool IsProcessedParallelStream(uint32_t stream_id); | bool IsProcessedParallelStream(uint32_t stream_id); | ||||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams); | |||||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); | |||||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| uint32_t total_common_stream_num_{0}; | uint32_t total_common_stream_num_{0}; | ||||
| uint32_t total_independ_stream_num_{0}; | uint32_t total_independ_stream_num_{0}; | ||||
| @@ -28,14 +28,14 @@ namespace device { | |||||
| namespace ascend { | namespace ascend { | ||||
| class PluginImpl : public PluginIntf { | class PluginImpl : public PluginIntf { | ||||
| public: | public: | ||||
| explicit PluginImpl(const std::string& module); | |||||
| explicit PluginImpl(const std::string &module); | |||||
| ~PluginImpl() override = default; | ~PluginImpl() override = default; | ||||
| int Init(const Reporter* reporter) override; | |||||
| int Init(const Reporter *reporter) override; | |||||
| int UnInit() override; | int UnInit() override; | ||||
| static Reporter* GetPluginReporter() { return reporter_; } | |||||
| static Reporter *GetPluginReporter() { return reporter_; } | |||||
| private: | private: | ||||
| static Reporter* reporter_; | |||||
| static Reporter *reporter_; | |||||
| std::string module_; | std::string module_; | ||||
| }; | }; | ||||
| } // namespace ascend | } // namespace ascend | ||||
| @@ -20,12 +20,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| PluginIntf* ProfilingEngineImpl::CreatePlugin() { | |||||
| PluginIntf *ProfilingEngineImpl::CreatePlugin() { | |||||
| MS_LOG(INFO) << "Create Plugin."; | MS_LOG(INFO) << "Create Plugin."; | ||||
| return new (std::nothrow) PluginImpl("Framework"); | return new (std::nothrow) PluginImpl("Framework"); | ||||
| } | } | ||||
| int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { | |||||
| int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { | |||||
| if (plugin != nullptr) { | if (plugin != nullptr) { | ||||
| delete plugin; | delete plugin; | ||||
| } | } | ||||
| @@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { | |||||
| ProfilingEngineImpl() = default; | ProfilingEngineImpl() = default; | ||||
| ~ProfilingEngineImpl() override = default; | ~ProfilingEngineImpl() override = default; | ||||
| PluginIntf* CreatePlugin() override; | |||||
| int ReleasePlugin(PluginIntf* plugin) override; | |||||
| PluginIntf *CreatePlugin() override; | |||||
| int ReleasePlugin(PluginIntf *plugin) override; | |||||
| }; | }; | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| @@ -35,7 +35,7 @@ using Json = nlohmann::json; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| ProfilingManager& ProfilingManager::GetInstance() { | |||||
| ProfilingManager &ProfilingManager::GetInstance() { | |||||
| static ProfilingManager inst; | static ProfilingManager inst; | ||||
| return inst; | return inst; | ||||
| } | } | ||||
| @@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { | |||||
| } | } | ||||
| uint64_t ProfilingManager::GetJobId() const { | uint64_t ProfilingManager::GetJobId() const { | ||||
| const char* job_id = std::getenv("JOB_ID"); | |||||
| const char *job_id = std::getenv("JOB_ID"); | |||||
| return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); | return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); | ||||
| } | } | ||||
| bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const { | |||||
| bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const { | |||||
| if (!IsProfiling()) { | if (!IsProfiling()) { | ||||
| MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | ||||
| return false; | return false; | ||||
| @@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI | |||||
| MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); | MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); | ||||
| Msprof::Engine::ReporterData reporter_data = {}; | Msprof::Engine::ReporterData reporter_data = {}; | ||||
| for (const auto& iter : op_taskId_map) { | |||||
| for (const auto &iter : op_taskId_map) { | |||||
| auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; | auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; | ||||
| reporter_data.deviceId = UintToInt(device_id_); | reporter_data.deviceId = UintToInt(device_id_); | ||||
| reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str())); | |||||
| reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str())); | |||||
| reporter_data.dataLen = data.size(); | reporter_data.dataLen = data.size(); | ||||
| auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); | auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| @@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI | |||||
| return true; | return true; | ||||
| } | } | ||||
| static std::vector<std::string> Split(const std::string& str, const char delim) { | |||||
| static std::vector<std::string> Split(const std::string &str, const char delim) { | |||||
| std::vector<std::string> elems; | std::vector<std::string> elems; | ||||
| if (str.empty()) { | if (str.empty()) { | ||||
| @@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { | |||||
| device_id_ = device_id; | device_id_ = device_id; | ||||
| // exp: export PROFILING_MODE=true | // exp: export PROFILING_MODE=true | ||||
| // export PROFILING_OPTIONS=training_trace | // export PROFILING_OPTIONS=training_trace | ||||
| const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); | |||||
| const char *prof_options_str = std::getenv("PROFILING_OPTIONS"); | |||||
| // register Framework to profiling | // register Framework to profiling | ||||
| int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); | int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); | ||||
| if (result != 0) { | if (result != 0) { | ||||
| @@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { | |||||
| MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; | ||||
| return true; | return true; | ||||
| } | } | ||||
| Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); | |||||
| Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | |||||
| if (reporter != nullptr) { | if (reporter != nullptr) { | ||||
| MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); | MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); | ||||
| } | } | ||||
| @@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, | |||||
| class GpuQueue { | class GpuQueue { | ||||
| public: | public: | ||||
| GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); | |||||
| GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity); | |||||
| virtual ~GpuQueue(); | virtual ~GpuQueue(); | ||||
| void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; } | |||||
| void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; } | |||||
| inline bool IsEmpty() const { return head_ == tail_; } | inline bool IsEmpty() const { return head_ == tail_; } | ||||
| inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } | inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } | ||||
| BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); | |||||
| BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; | |||||
| BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size); | |||||
| BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const; | |||||
| BlockQueueStatus_T Pop(); | BlockQueueStatus_T Pop(); | ||||
| bool Destroy(); | bool Destroy(); | ||||
| private: | private: | ||||
| struct NodeInfo { | struct NodeInfo { | ||||
| std::unique_ptr<cudaEvent_t> event_; | std::unique_ptr<cudaEvent_t> event_; | ||||
| void* host_feature_addr_; | |||||
| void* host_label_addr_; | |||||
| void *host_feature_addr_; | |||||
| void *host_label_addr_; | |||||
| }; | }; | ||||
| void* buffer_; | |||||
| void *buffer_; | |||||
| size_t head_; | size_t head_; | ||||
| size_t tail_; | size_t tail_; | ||||
| size_t feature_size_; | size_t feature_size_; | ||||
| @@ -61,10 +61,10 @@ class GpuQueue { | |||||
| size_t capacity_; | size_t capacity_; | ||||
| cudaStream_t stream_; | cudaStream_t stream_; | ||||
| std::unique_ptr<NodeInfo[]> node_info_; | std::unique_ptr<NodeInfo[]> node_info_; | ||||
| std::function<void(void*)> host_release_; | |||||
| std::function<void(void *)> host_release_; | |||||
| GpuQueue(const GpuQueue&) = delete; | |||||
| GpuQueue& operator=(const GpuQueue&) = delete; | |||||
| GpuQueue(const GpuQueue &) = delete; | |||||
| GpuQueue &operator=(const GpuQueue &) = delete; | |||||
| }; | }; | ||||
| class BlockingQueue { | class BlockingQueue { | ||||
| @@ -72,11 +72,11 @@ class BlockingQueue { | |||||
| BlockingQueue() : queue_(nullptr) {} | BlockingQueue() : queue_(nullptr) {} | ||||
| ~BlockingQueue() = default; | ~BlockingQueue() = default; | ||||
| BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); | |||||
| void RegisterRelease(const std::function<void(void*)>& func); | |||||
| BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, | |||||
| BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity); | |||||
| void RegisterRelease(const std::function<void(void *)> &func); | |||||
| BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size, | |||||
| unsigned int timeout_in_sec); | unsigned int timeout_in_sec); | ||||
| BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); | |||||
| BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size); | |||||
| BlockQueueStatus_T Pop(); | BlockQueueStatus_T Pop(); | ||||
| bool Destroy(); | bool Destroy(); | ||||
| @@ -20,17 +20,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace gpu { | namespace gpu { | ||||
| CollectiveInitializer& CollectiveInitializer::instance() { | |||||
| CollectiveInitializer &CollectiveInitializer::instance() { | |||||
| static CollectiveInitializer instance = {}; | static CollectiveInitializer instance = {}; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| bool CollectiveInitializer::collective_inited() const { return collective_inited_; } | bool CollectiveInitializer::collective_inited() const { return collective_inited_; } | ||||
| const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } | |||||
| const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } | |||||
| void CollectiveInitializer::InitCollective() { | void CollectiveInitializer::InitCollective() { | ||||
| void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); | |||||
| void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); | |||||
| if (handle == nullptr) { | if (handle == nullptr) { | ||||
| MS_LOG(EXCEPTION) | MS_LOG(EXCEPTION) | ||||
| << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " | << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " | ||||
| @@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { | |||||
| CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); | CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); | ||||
| } | } | ||||
| bool GPUDeviceManager::CreateStream(DeviceStream* stream) { | |||||
| bool GPUDeviceManager::CreateStream(DeviceStream *stream) { | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); | CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); | ||||
| gpu_streams_.emplace_back(*stream); | gpu_streams_.emplace_back(*stream); | ||||
| return true; | return true; | ||||
| } | } | ||||
| const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } | |||||
| const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } | |||||
| int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } | int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } | ||||
| @@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } | |||||
| bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } | bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } | ||||
| const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } | |||||
| const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } | |||||
| const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } | |||||
| const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } | |||||
| bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } | |||||
| bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } | |||||
| bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { | |||||
| bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { | |||||
| return CudaDriver::CopyDeviceMemToHost(dst, src, size); | return CudaDriver::CopyDeviceMemToHost(dst, src, size); | ||||
| } | } | ||||
| bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { | |||||
| bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { | |||||
| return CudaDriver::CopyHostMemToDevice(dst, src, size); | return CudaDriver::CopyHostMemToDevice(dst, src, size); | ||||
| } | } | ||||
| } // namespace gpu | } // namespace gpu | ||||
| @@ -37,17 +37,17 @@ class GPUDeviceManager { | |||||
| uint32_t cur_device_id() const; | uint32_t cur_device_id() const; | ||||
| bool is_device_id_init() const; | bool is_device_id_init() const; | ||||
| bool CreateStream(DeviceStream* stream); | |||||
| bool SyncStream(const DeviceStream& stream) const; | |||||
| const DeviceStream& default_stream() const; | |||||
| bool CreateStream(DeviceStream *stream); | |||||
| bool SyncStream(const DeviceStream &stream) const; | |||||
| const DeviceStream &default_stream() const; | |||||
| const cudnnHandle_t& GetCudnnHandle() const; | |||||
| const cublasHandle_t& GetCublasHandle() const; | |||||
| const cudnnHandle_t &GetCudnnHandle() const; | |||||
| const cublasHandle_t &GetCublasHandle() const; | |||||
| bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; | |||||
| bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; | |||||
| bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; | |||||
| bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; | |||||
| static GPUDeviceManager& GetInstance() { | |||||
| static GPUDeviceManager &GetInstance() { | |||||
| static GPUDeviceManager instance; | static GPUDeviceManager instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| @@ -55,8 +55,8 @@ class GPUDeviceManager { | |||||
| private: | private: | ||||
| GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} | GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} | ||||
| ~GPUDeviceManager() = default; | ~GPUDeviceManager() = default; | ||||
| GPUDeviceManager(const GPUDeviceManager&) = delete; | |||||
| GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; | |||||
| GPUDeviceManager(const GPUDeviceManager &) = delete; | |||||
| GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; | |||||
| // default CUDA stream used for all the kernels. | // default CUDA stream used for all the kernels. | ||||
| DeviceStream default_stream_{nullptr}; | DeviceStream default_stream_{nullptr}; | ||||
| @@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { | |||||
| bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { | |||||
| auto alloc_size = AllocDeviceMem(size, addr); | auto alloc_size = AllocDeviceMem(size, addr); | ||||
| buffer_q_addr_ = *addr; | buffer_q_addr_ = *addr; | ||||
| // Buffer queue needs to ensure that the alloc_size and size is equal. | // Buffer queue needs to ensure that the alloc_size and size is equal. | ||||
| return (alloc_size == size) ? true : false; | return (alloc_size == size) ? true : false; | ||||
| } | } | ||||
| size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||||
| size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { | |||||
| if (size == 0) { | if (size == 0) { | ||||
| MS_LOG(EXCEPTION) << "The memory alloc size is 0."; | MS_LOG(EXCEPTION) << "The memory alloc size is 0."; | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { | |||||
| return alloc_size; | return alloc_size; | ||||
| } | } | ||||
| bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } | |||||
| bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } | |||||
| size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } | size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } | ||||
| @@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { | |||||
| ~GPUMemoryAllocator() override = default; | ~GPUMemoryAllocator() override = default; | ||||
| bool Init(); | bool Init(); | ||||
| bool Finalize(); | bool Finalize(); | ||||
| bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); | |||||
| bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; | |||||
| bool FreeDeviceMem(const DeviceMemPtr& addr) override; | |||||
| size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; | |||||
| bool FreeDeviceMem(const DeviceMemPtr &addr) override; | |||||
| size_t free_mem_size() override; | size_t free_mem_size() override; | ||||
| size_t total_mem_size() override; | size_t total_mem_size() override; | ||||
| static GPUMemoryAllocator& GetInstance() { | |||||
| static GPUMemoryAllocator &GetInstance() { | |||||
| static GPUMemoryAllocator instance; | static GPUMemoryAllocator instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| private: | private: | ||||
| GPUMemoryAllocator() = default; | GPUMemoryAllocator() = default; | ||||
| GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; | |||||
| GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; | |||||
| GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; | |||||
| GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; | |||||
| // Used to track address of data buffer queue. | // Used to track address of data buffer queue. | ||||
| DeviceMemPtr buffer_q_addr_{nullptr}; | DeviceMemPtr buffer_q_addr_{nullptr}; | ||||
| @@ -33,8 +33,8 @@ namespace gpu { | |||||
| using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; | ||||
| using mindspore::kernel::KernelBuildInfo; | using mindspore::kernel::KernelBuildInfo; | ||||
| namespace { | namespace { | ||||
| bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info, | |||||
| const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { | |||||
| bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info, | |||||
| const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) { | |||||
| MS_EXCEPTION_IF_NULL(selected_kernel_info); | MS_EXCEPTION_IF_NULL(selected_kernel_info); | ||||
| MS_EXCEPTION_IF_NULL(alternative_kernel_info); | MS_EXCEPTION_IF_NULL(alternative_kernel_info); | ||||
| size_t selected_input_num = selected_kernel_info->GetInputNum(); | size_t selected_input_num = selected_kernel_info->GetInputNum(); | ||||
| @@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_ | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::string SupportedTypeList(const CNodePtr& kernel_node) { | |||||
| std::string SupportedTypeList(const CNodePtr &kernel_node) { | |||||
| std::string supported_type_lists = | std::string supported_type_lists = | ||||
| kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); | kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); | ||||
| if (!supported_type_lists.empty()) { | if (!supported_type_lists.empty()) { | ||||
| @@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { | |||||
| return supported_type_lists; | return supported_type_lists; | ||||
| } | } | ||||
| bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { | |||||
| bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(selected_kernel_info); | MS_EXCEPTION_IF_NULL(selected_kernel_info); | ||||
| std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list; | std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list; | ||||
| @@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu | |||||
| } | } | ||||
| bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), | ||||
| [&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) { | |||||
| [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) { | |||||
| return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); | return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); | ||||
| }); | }); | ||||
| if (!match) { | if (!match) { | ||||
| @@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu | |||||
| return true; | return true; | ||||
| } | } | ||||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) { | |||||
| void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | ||||
| auto input_kernel_node = kernel_node->input(input_index + 1); | auto input_kernel_node = kernel_node->input(input_index + 1); | ||||
| @@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void SetKernelInfo(const CNodePtr& kernel_node) { | |||||
| void SetKernelInfo(const CNodePtr &kernel_node) { | |||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace gpu { | namespace gpu { | ||||
| void SetKernelInfo(const CNodePtr& apply_kernel_ptr); | |||||
| void SetKernelInfo(const CNodePtr &apply_kernel_ptr); | |||||
| class KernelAttr { | class KernelAttr { | ||||
| public: | public: | ||||
| @@ -35,24 +35,24 @@ class KernelAttr { | |||||
| KernelAttr() : all_same_(false) {} | KernelAttr() : all_same_(false) {} | ||||
| ~KernelAttr() = default; | ~KernelAttr() = default; | ||||
| KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { | |||||
| KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||||
| input_type_.emplace_back(ms_type, format); | input_type_.emplace_back(ms_type, format); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { | |||||
| KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { | |||||
| output_type_.emplace_back(ms_type, format); | output_type_.emplace_back(ms_type, format); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| KernelAttr& AddAllSameAttr(const bool& all_same) { | |||||
| KernelAttr &AddAllSameAttr(const bool &all_same) { | |||||
| all_same_ = all_same; | all_same_ = all_same; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } | |||||
| const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||||
| const bool& GetAllSame() const { return all_same_; } | |||||
| const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } | |||||
| const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } | |||||
| const bool &GetAllSame() const { return all_same_; } | |||||
| size_t GetInputSize() const { return input_type_.size(); } | size_t GetInputSize() const { return input_type_.size(); } | ||||
| size_t GetOutputSize() const { return output_type_.size(); } | size_t GetOutputSize() const { return output_type_.size(); } | ||||
| @@ -24,7 +24,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| struct TypeIdManager* TypeIdManager::Get() { | |||||
| struct TypeIdManager *TypeIdManager::Get() { | |||||
| static TypeIdManager manager; | static TypeIdManager manager; | ||||
| return &manager; | return &manager; | ||||
| } | } | ||||
| @@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra | |||||
| BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } | BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } | ||||
| std::string AnfNode::ToString() const { | std::string AnfNode::ToString() const { | ||||
| return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info()); | |||||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | |||||
| } | } | ||||
| CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph) | |||||
| CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | |||||
| : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} | : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} | ||||
| // Check if CNode is an apply with the specific Primitive. | // Check if CNode is an apply with the specific Primitive. | ||||
| bool CNode::IsApply(const PrimitivePtr& value) const { | |||||
| bool CNode::IsApply(const PrimitivePtr &value) const { | |||||
| if (value == nullptr) { | if (value == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } | |||||
| void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } | |||||
| std::string CNode::DebugString(int recursive_level) const { | std::string CNode::DebugString(int recursive_level) const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| @@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { | |||||
| buffer << ToString() << "{"; | buffer << ToString() << "{"; | ||||
| bool is_first_node = true; | bool is_first_node = true; | ||||
| int idx = 0; | int idx = 0; | ||||
| for (auto& node : inputs_) { | |||||
| for (auto &node : inputs_) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (is_first_node) { | if (is_first_node) { | ||||
| is_first_node = false; | is_first_node = false; | ||||
| @@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { | |||||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||||
| if (operator_info_ != nullptr) { | if (operator_info_ != nullptr) { | ||||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | ||||
| << ", using the new one: " << operator_info->name(); | << ", using the new one: " << operator_info->name(); | ||||
| @@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { | |||||
| return fullname_with_scope_; | return fullname_with_scope_; | ||||
| } | } | ||||
| void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); } | |||||
| void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); } | |||||
| void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); } | |||||
| void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); } | |||||
| void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); } | |||||
| void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); } | |||||
| bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { | |||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode != nullptr) { | if (cnode != nullptr) { | ||||
| @@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { | |||||
| PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { | |||||
| return ""; | return ""; | ||||
| } | } | ||||
| bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { | |||||
| bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { | |||||
| if (IsValueNode<Primitive>(node)) { | if (IsValueNode<Primitive>(node)) { | ||||
| PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node); | PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node); | ||||
| MS_EXCEPTION_IF_NULL(value); | MS_EXCEPTION_IF_NULL(value); | ||||
| @@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { | |||||
| } | } | ||||
| namespace id_generator { | namespace id_generator { | ||||
| static std::unordered_map<std::string, int> node_ids; | static std::unordered_map<std::string, int> node_ids; | ||||
| std::string get_id(const AnfNodePtr& node) { | |||||
| std::string get_id(const AnfNodePtr &node) { | |||||
| auto type_name = node->type_name(); | auto type_name = node->type_name(); | ||||
| if (node_ids.find(type_name) == node_ids.end()) { | if (node_ids.find(type_name) == node_ids.end()) { | ||||
| node_ids[type_name] = 0; | node_ids[type_name] = 0; | ||||
| @@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {}; | |||||
| class Base : public std::enable_shared_from_this<Base> { | class Base : public std::enable_shared_from_this<Base> { | ||||
| public: | public: | ||||
| constexpr Base() = default; | constexpr Base() = default; | ||||
| Base(const Base& other) : std::enable_shared_from_this<Base>(other) {} | |||||
| virtual bool operator==(const Base& rhs) { | |||||
| Base(const Base &other) : std::enable_shared_from_this<Base>(other) {} | |||||
| virtual bool operator==(const Base &rhs) { | |||||
| if (this == &rhs) { | if (this == &rhs) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| virtual Base& operator=(const Base&) { return *this; } | |||||
| virtual Base &operator=(const Base &) { return *this; } | |||||
| virtual ~Base() = default; | virtual ~Base() = default; | ||||
| virtual std::size_t hash() const { return tid(); } | virtual std::size_t hash() const { return tid(); } | ||||
| virtual std::string ToString() const { return type_name(); } | virtual std::string ToString() const { return type_name(); } | ||||
| @@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> { | |||||
| virtual const bool IsFromTypeId(uint32_t tid) const; | virtual const bool IsFromTypeId(uint32_t tid) const; | ||||
| virtual std::string type_name() const { return "Base"; } | virtual std::string type_name() const { return "Base"; } | ||||
| static uint32_t GetTypeId(const char* const type_key); | |||||
| static uint32_t GetTypeId(const char *const type_key); | |||||
| virtual uint32_t tid() const { | virtual uint32_t tid() const { | ||||
| static const uint32_t tid = GetTypeId(typeid(Base).name()); | static const uint32_t tid = GetTypeId(typeid(Base).name()); | ||||
| return tid; | return tid; | ||||
| } | } | ||||
| template <typename T, | template <typename T, | ||||
| typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr> | |||||
| typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr> | |||||
| inline bool isa() const { | inline bool isa() const { | ||||
| static const uint32_t tid = GetTypeId(typeid(T).name()); | static const uint32_t tid = GetTypeId(typeid(T).name()); | ||||
| return this->IsFromTypeId(tid); | return this->IsFromTypeId(tid); | ||||
| @@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>; | |||||
| using BaseWeakPtr = std::weak_ptr<Base>; | using BaseWeakPtr = std::weak_ptr<Base>; | ||||
| template <typename T, typename U> | template <typename T, typename U> | ||||
| inline T* cast(U* source) { | |||||
| inline T *cast(U *source) { | |||||
| if (source != nullptr && source->template isa<T>()) { | if (source != nullptr && source->template isa<T>()) { | ||||
| return static_cast<T*>(source); | |||||
| return static_cast<T *>(source); | |||||
| } else { | } else { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -100,7 +100,7 @@ inline T* cast(U* source) { | |||||
| template < | template < | ||||
| typename T, typename U, | typename T, typename U, | ||||
| typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr> | |||||
| typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr> | |||||
| inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) { | inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) { | ||||
| if (r != nullptr && r->template isa<T>()) { | if (r != nullptr && r->template isa<T>()) { | ||||
| return std::static_pointer_cast<T>(r); | return std::static_pointer_cast<T>(r); | ||||
| @@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { | |||||
| std::mutex mutex; | std::mutex mutex; | ||||
| std::atomic<uint32_t> type_counter{0}; | std::atomic<uint32_t> type_counter{0}; | ||||
| std::unordered_map<std::string, uint32_t> map; | std::unordered_map<std::string, uint32_t> map; | ||||
| static TypeIdManager* Get(); | |||||
| static TypeIdManager *Get(); | |||||
| TypeIdManager() : mutex(), type_counter(0), map() {} | TypeIdManager() : mutex(), type_counter(0), map() {} | ||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -48,11 +48,11 @@ std::string Keyword::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| bool Keyword::operator==(const Type& other) const { | |||||
| bool Keyword::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| const auto& other_keyword = static_cast<const Keyword&>(other); | |||||
| const auto &other_keyword = static_cast<const Keyword &>(other); | |||||
| return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); | return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); | ||||
| } | } | ||||
| @@ -87,11 +87,11 @@ std::string Slice::ToString() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| bool Slice::operator==(const Type& other) const { | |||||
| bool Slice::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_slice = static_cast<const Slice&>(other); | |||||
| auto other_slice = static_cast<const Slice &>(other); | |||||
| return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); | return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); | ||||
| } | } | ||||
| @@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { | |||||
| } | } | ||||
| } | } | ||||
| bool TensorType::operator==(const Type& other) const { | |||||
| bool TensorType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_elem_type = static_cast<const TensorType&>(other).element_type_; | |||||
| auto other_elem_type = static_cast<const TensorType &>(other).element_type_; | |||||
| // When element_type_ = nullptr, which means any type of Array. | // When element_type_ = nullptr, which means any type of Array. | ||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | if (element_type_ == nullptr && other_elem_type == nullptr) { | ||||
| return true; | return true; | ||||
| @@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { | |||||
| retval_ = nullptr; | retval_ = nullptr; | ||||
| } | } | ||||
| Function::Function(const std::vector<TypePtr>& args, const TypePtr retval) | |||||
| Function::Function(const std::vector<TypePtr> &args, const TypePtr retval) | |||||
| : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} | : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} | ||||
| TypePtr Function::DeepCopy() const { | TypePtr Function::DeepCopy() const { | ||||
| @@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { | |||||
| TypePtrList args; | TypePtrList args; | ||||
| TypePtr retval = nullptr; | TypePtr retval = nullptr; | ||||
| (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), | (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), | ||||
| [](const TypePtr& arg) { return arg->DeepCopy(); }); | |||||
| [](const TypePtr &arg) { return arg->DeepCopy(); }); | |||||
| if (retval_ != nullptr) { | if (retval_ != nullptr) { | ||||
| retval = retval_->DeepCopy(); | retval = retval_->DeepCopy(); | ||||
| } | } | ||||
| @@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { | |||||
| } | } | ||||
| } | } | ||||
| bool Function::operator==(const Type& other) const { | |||||
| bool Function::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| const auto& other_function = static_cast<const Function&>(other); | |||||
| const auto &other_function = static_cast<const Function &>(other); | |||||
| if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { | if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { | ||||
| if (*retval_ != *other_function.retval_) { | if (*retval_ != *other_function.retval_) { | ||||
| return false; | return false; | ||||
| @@ -188,7 +188,7 @@ std::string Function::ToString() const { | |||||
| } else { | } else { | ||||
| buffer << "Func[("; | buffer << "Func[("; | ||||
| bool begin = true; | bool begin = true; | ||||
| for (auto& attr : args_) { | |||||
| for (auto &attr : args_) { | |||||
| if (!begin) { | if (!begin) { | ||||
| buffer << ", "; | buffer << ", "; | ||||
| } else { | } else { | ||||
| @@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) { | |||||
| std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) { | |||||
| MS_EXCEPTION_IF_NULL(problem); | MS_EXCEPTION_IF_NULL(problem); | ||||
| os << problem->ToString(); | os << problem->ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| std::size_t TypeHasher::operator()(TypePtr const& type) const { | |||||
| std::size_t TypeHasher::operator()(TypePtr const &type) const { | |||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| std::size_t hash = std::hash<size_t>()(type->type_id()); | std::size_t hash = std::hash<size_t>()(type->type_id()); | ||||
| return hash; | return hash; | ||||
| } | } | ||||
| std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { | |||||
| std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { | |||||
| std::size_t hash_sum = 0; | std::size_t hash_sum = 0; | ||||
| for (auto& type : type_list) { | |||||
| for (auto &type : type_list) { | |||||
| auto type_id = static_cast<std::size_t>(type->type_id()); | auto type_id = static_cast<std::size_t>(type->type_id()); | ||||
| hash_sum = hash_combine(hash_sum, type_id); | hash_sum = hash_combine(hash_sum, type_id); | ||||
| } | } | ||||
| return hash_sum; | return hash_sum; | ||||
| } | } | ||||
| bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { | |||||
| bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { | |||||
| MS_EXCEPTION_IF_NULL(t1); | MS_EXCEPTION_IF_NULL(t1); | ||||
| MS_EXCEPTION_IF_NULL(t2); | MS_EXCEPTION_IF_NULL(t2); | ||||
| return t1->type_id() == t2->type_id(); | return t1->type_id() == t2->type_id(); | ||||
| } | } | ||||
| bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { | |||||
| bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { | |||||
| if (lhs.size() != rhs.size()) { | if (lhs.size() != rhs.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { | |||||
| namespace { | namespace { | ||||
| template <typename T> | template <typename T> | ||||
| TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { | |||||
| TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == num_type_name) { | if (type_name == num_type_name) { | ||||
| type = std::make_shared<T>(); | type = std::make_shared<T>(); | ||||
| @@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ | |||||
| } | } | ||||
| auto bits = std::stoi(type_name.substr(num_type_name.size())); | auto bits = std::stoi(type_name.substr(num_type_name.size())); | ||||
| type = std::make_shared<T>(bits); | type = std::make_shared<T>(bits); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); | MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); | ||||
| } | } | ||||
| } | } | ||||
| return type; | return type; | ||||
| } | } | ||||
| std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { | |||||
| std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) { | |||||
| std::vector<TypePtr> types; | std::vector<TypePtr> types; | ||||
| if (type_names.length() == 0) { | if (type_names.length() == 0) { | ||||
| return types; | return types; | ||||
| @@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { | |||||
| return types; | return types; | ||||
| } | } | ||||
| TypePtr TensorStrToType(const std::string& type_name) { | |||||
| TypePtr TensorStrToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == "Tensor") { | if (type_name == "Tensor") { | ||||
| type = std::make_shared<TensorType>(); | type = std::make_shared<TensorType>(); | ||||
| @@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| type = std::make_shared<TensorType>(element_type); | type = std::make_shared<TensorType>(element_type); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| TypePtr ListStrToType(const std::string& type_name) { | |||||
| TypePtr ListStrToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == "List") { | if (type_name == "List") { | ||||
| type = std::make_shared<List>(); | type = std::make_shared<List>(); | ||||
| @@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { | |||||
| std::string element_strs = type_name.substr(start, end - start); | std::string element_strs = type_name.substr(start, end - start); | ||||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | ||||
| bool wrong = | bool wrong = | ||||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); | |||||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||||
| if (wrong) { | if (wrong) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| type = std::make_shared<List>(element_types); | type = std::make_shared<List>(element_types); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| TypePtr TupleStrToType(const std::string& type_name) { | |||||
| TypePtr TupleStrToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == "Tuple") { | if (type_name == "Tuple") { | ||||
| type = std::make_shared<Tuple>(); | type = std::make_shared<Tuple>(); | ||||
| @@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { | |||||
| std::string element_strs = type_name.substr(start, end - start); | std::string element_strs = type_name.substr(start, end - start); | ||||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | ||||
| bool wrong = | bool wrong = | ||||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); | |||||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||||
| if (wrong) { | if (wrong) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| type = std::make_shared<Tuple>(element_types); | type = std::make_shared<Tuple>(element_types); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | ||||
| } | } | ||||
| } | } | ||||
| return type; | return type; | ||||
| } | } | ||||
| TypePtr FunctionStrToType(const std::string& type_name) { | |||||
| TypePtr FunctionStrToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == "Function") { | if (type_name == "Function") { | ||||
| @@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { | |||||
| std::vector<TypePtr> args_type = StringToVectorOfType(str_args); | std::vector<TypePtr> args_type = StringToVectorOfType(str_args); | ||||
| TypePtr retval = StringToType(str_retval); | TypePtr retval = StringToType(str_retval); | ||||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); | |||||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); | |||||
| if (retval == nullptr || wrong) { | if (retval == nullptr || wrong) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| type = std::make_shared<Function>(args_type, retval); | type = std::make_shared<Function>(args_type, retval); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) { | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| TypePtr StringToType(const std::string& type_name) { | |||||
| TypePtr StringToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name.compare("None") == 0) { | if (type_name.compare("None") == 0) { | ||||
| type = std::make_shared<TypeNone>(); | type = std::make_shared<TypeNone>(); | ||||
| @@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { | |||||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | |||||
| if (x == nullptr || base_type == nullptr) { | if (x == nullptr || base_type == nullptr) { | ||||
| MS_LOG(ERROR) << "Type is nullptr."; | MS_LOG(ERROR) << "Type is nullptr."; | ||||
| return false; | return false; | ||||
| @@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { | |||||
| } | } | ||||
| } | } | ||||
| bool IsSubType(TypePtr const& t1, TypePtr const& t2) { | |||||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2) { | |||||
| MS_EXCEPTION_IF_NULL(t1); | MS_EXCEPTION_IF_NULL(t1); | ||||
| if (t1->type_id() == kTypeUnknown) { | if (t1->type_id() == kTypeUnknown) { | ||||
| return false; | return false; | ||||
| @@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { | |||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE( | REGISTER_PYBIND_DEFINE( | ||||
| typing, ([](py::module* const m) { | |||||
| typing, ([](py::module *const m) { | |||||
| auto m_sub = m->def_submodule("typing", "submodule for dtype"); | auto m_sub = m->def_submodule("typing", "submodule for dtype"); | ||||
| py::enum_<TypeId>(m_sub, "TypeId"); | py::enum_<TypeId>(m_sub, "TypeId"); | ||||
| (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); | (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); | ||||
| (void)m_sub.def("load_type", &TypeIdToType, "load type"); | (void)m_sub.def("load_type", &TypeIdToType, "load type"); | ||||
| (void)m_sub.def( | (void)m_sub.def( | ||||
| "dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); | |||||
| "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); | |||||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | ||||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | ||||
| .def("__eq__", | .def("__eq__", | ||||
| [](const TypePtr& t1, const TypePtr& t2) { | |||||
| [](const TypePtr &t1, const TypePtr &t2) { | |||||
| if (t1 != nullptr && t2 != nullptr) { | if (t1 != nullptr && t2 != nullptr) { | ||||
| return *t1 == *t2; | return *t1 == *t2; | ||||
| } | } | ||||
| @@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE( | |||||
| .def("__hash__", &Type::hash) | .def("__hash__", &Type::hash) | ||||
| .def("__str__", &Type::ToString) | .def("__str__", &Type::ToString) | ||||
| .def("__repr__", &Type::ReprString) | .def("__repr__", &Type::ReprString) | ||||
| .def("__deepcopy__", [](const TypePtr& t, py::dict) { | |||||
| .def("__deepcopy__", [](const TypePtr &t, py::dict) { | |||||
| if (t == nullptr) { | if (t == nullptr) { | ||||
| return static_cast<TypePtr>(nullptr); | return static_cast<TypePtr>(nullptr); | ||||
| } | } | ||||
| @@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE( | |||||
| (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const Bool&) { // __getstate__ | |||||
| [](const Bool &) { // __getstate__ | |||||
| return py::make_tuple(); | return py::make_tuple(); | ||||
| }, | }, | ||||
| [](const py::tuple&) { // __setstate__ | |||||
| [](const py::tuple &) { // __setstate__ | |||||
| return std::make_shared<Bool>(); | return std::make_shared<Bool>(); | ||||
| })); | })); | ||||
| (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init<int>(), py::arg("nbits")) | .def(py::init<int>(), py::arg("nbits")) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const Int& t) { // __getstate__ | |||||
| [](const Int &t) { // __getstate__ | |||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| return py::make_tuple(py::int_(t.nbits())); | return py::make_tuple(py::int_(t.nbits())); | ||||
| }, | }, | ||||
| [](const py::tuple& t) { // __setstate__ | |||||
| [](const py::tuple &t) { // __setstate__ | |||||
| if (t.size() != 1) { | if (t.size() != 1) { | ||||
| throw std::runtime_error("Invalid state!"); | throw std::runtime_error("Invalid state!"); | ||||
| } | } | ||||
| @@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE( | |||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init<int>(), py::arg("nbits")) | .def(py::init<int>(), py::arg("nbits")) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const UInt& t) { // __getstate__ | |||||
| [](const UInt &t) { // __getstate__ | |||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| return py::make_tuple(py::int_(t.nbits())); | return py::make_tuple(py::int_(t.nbits())); | ||||
| }, | }, | ||||
| [](const py::tuple& t) { // __setstate__ | |||||
| [](const py::tuple &t) { // __setstate__ | |||||
| if (t.size() != 1) { | if (t.size() != 1) { | ||||
| throw std::runtime_error("Invalid state!"); | throw std::runtime_error("Invalid state!"); | ||||
| } | } | ||||
| @@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE( | |||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init<int>(), py::arg("nbits")) | .def(py::init<int>(), py::arg("nbits")) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const Float& t) { // __getstate__ | |||||
| [](const Float &t) { // __getstate__ | |||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| return py::make_tuple(py::int_(t.nbits())); | return py::make_tuple(py::int_(t.nbits())); | ||||
| }, | }, | ||||
| [](const py::tuple& t) { // __setstate__ | |||||
| [](const py::tuple &t) { // __setstate__ | |||||
| if (t.size() != 1) { | if (t.size() != 1) { | ||||
| throw std::runtime_error("Invalid state!"); | throw std::runtime_error("Invalid state!"); | ||||
| } | } | ||||
| @@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE( | |||||
| .def(py::init<TypePtr>(), py::arg("element")) | .def(py::init<TypePtr>(), py::arg("element")) | ||||
| .def("element_type", &TensorType::element) | .def("element_type", &TensorType::element) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const TensorType& t) { // __getstate__ | |||||
| [](const TensorType &t) { // __getstate__ | |||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); | return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); | ||||
| }, | }, | ||||
| [](const py::tuple& t) { // __setstate__ | |||||
| [](const py::tuple &t) { // __setstate__ | |||||
| if (t.size() != 1) { | if (t.size() != 1) { | ||||
| throw std::runtime_error("Invalid state!"); | throw std::runtime_error("Invalid state!"); | ||||
| } | } | ||||
| @@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>; | |||||
| class Keyword : public Object { | class Keyword : public Object { | ||||
| public: | public: | ||||
| Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} | Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} | ||||
| Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} | |||||
| Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} | |||||
| ~Keyword() override = default; | ~Keyword() override = default; | ||||
| MS_DECLARE_PARENT(Keyword, Object) | MS_DECLARE_PARENT(Keyword, Object) | ||||
| @@ -70,7 +70,7 @@ class Keyword : public Object { | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| std::string GetKey() const { return key_; } | std::string GetKey() const { return key_; } | ||||
| TypePtr GetValue() const { return value_; } | TypePtr GetValue() const { return value_; } | ||||
| @@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>; | |||||
| class Slice : public Object { | class Slice : public Object { | ||||
| public: | public: | ||||
| Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} | Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} | ||||
| Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) | |||||
| Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) | |||||
| : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} | : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} | ||||
| ~Slice() override = default; | ~Slice() override = default; | ||||
| @@ -95,7 +95,7 @@ class Slice : public Object { | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| TypePtr get_start() const { return start_; } | TypePtr get_start() const { return start_; } | ||||
| TypePtr get_stop() const { return stop_; } | TypePtr get_stop() const { return stop_; } | ||||
| @@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>; | |||||
| class TensorType : public Object { | class TensorType : public Object { | ||||
| public: | public: | ||||
| TensorType() : Object(kObjectTypeTensorType) {} | TensorType() : Object(kObjectTypeTensorType) {} | ||||
| explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} | |||||
| explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} | |||||
| ~TensorType() override = default; | ~TensorType() override = default; | ||||
| MS_DECLARE_PARENT(TensorType, Object) | MS_DECLARE_PARENT(TensorType, Object) | ||||
| TypeId generic_type_id() const override { return kObjectTypeTensorType; } | TypeId generic_type_id() const override { return kObjectTypeTensorType; } | ||||
| const TypePtr element() const { return element_type_; } | const TypePtr element() const { return element_type_; } | ||||
| void set_element(const TypePtr& element_type) { element_type_ = element_type; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string ToReprString() const override { return "tensor"; } | std::string ToReprString() const override { return "tensor"; } | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | private: | ||||
| TypePtr element_type_; | TypePtr element_type_; | ||||
| @@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>; | |||||
| class Function : public Object { | class Function : public Object { | ||||
| public: | public: | ||||
| Function(); | Function(); | ||||
| Function(const std::vector<TypePtr>& args, const TypePtr retval); | |||||
| Function(const std::vector<TypePtr> &args, const TypePtr retval); | |||||
| ~Function() override = default; | ~Function() override = default; | ||||
| MS_DECLARE_PARENT(Function, Object) | MS_DECLARE_PARENT(Function, Object) | ||||
| @@ -141,11 +141,11 @@ class Function : public Object { | |||||
| // Add temporarily for return abstraction to avoid type checking. | // Add temporarily for return abstraction to avoid type checking. | ||||
| bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } | bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } | ||||
| const std::vector<TypePtr>& args() const { return args_; } | |||||
| const TypePtr& retval() const { return retval_; } | |||||
| const std::vector<TypePtr> &args() const { return args_; } | |||||
| const TypePtr &retval() const { return retval_; } | |||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string ToReprString() const override { return "function"; } | std::string ToReprString() const override { return "function"; } | ||||
| @@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>; | |||||
| class JTagged : public Object { | class JTagged : public Object { | ||||
| public: | public: | ||||
| JTagged() : Object(kObjectTypeJTagged) {} | JTagged() : Object(kObjectTypeJTagged) {} | ||||
| explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} | |||||
| explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} | |||||
| ~JTagged() override = default; | ~JTagged() override = default; | ||||
| MS_DECLARE_PARENT(JTagged, Object) | MS_DECLARE_PARENT(JTagged, Object) | ||||
| @@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>; | |||||
| class Problem : public Type { | class Problem : public Type { | ||||
| public: | public: | ||||
| Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} | Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} | ||||
| explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} | |||||
| explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} | |||||
| ~Problem() override = default; | ~Problem() override = default; | ||||
| MS_DECLARE_PARENT(Problem, Type) | MS_DECLARE_PARENT(Problem, Type) | ||||
| @@ -222,7 +222,7 @@ class Problem : public Type { | |||||
| std::string ToString() const override { return kind_.name(); } | std::string ToString() const override { return kind_.name(); } | ||||
| std::string DumpText() const override { return "ProblemType"; } | std::string DumpText() const override { return "ProblemType"; } | ||||
| friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem); | |||||
| friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem); | |||||
| private: | private: | ||||
| Named kind_; | Named kind_; | ||||
| @@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>; | |||||
| // helper template | // helper template | ||||
| template <class T> | template <class T> | ||||
| TypePtr Clone(const T& t) { | |||||
| TypePtr Clone(const T &t) { | |||||
| return t.Clone(); | return t.Clone(); | ||||
| } | } | ||||
| TypePtr StringToType(const std::string& type_name); | |||||
| TypePtr StringToType(const std::string &type_name); | |||||
| // Judge whether x is predicate or is a subclass of predicate. | // Judge whether x is predicate or is a subclass of predicate. | ||||
| bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); | |||||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); | |||||
| // Whether t1 is identity or a subclass of t2. | // Whether t1 is identity or a subclass of t2. | ||||
| bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); | |||||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); | |||||
| struct TypeHasher { | struct TypeHasher { | ||||
| std::size_t operator()(TypePtr const& type) const; | |||||
| std::size_t operator()(TypePtr const &type) const; | |||||
| }; | }; | ||||
| struct TypeListHasher { | struct TypeListHasher { | ||||
| std::size_t operator()(const TypePtrList& type_list) const; | |||||
| std::size_t operator()(const TypePtrList &type_list) const; | |||||
| }; | }; | ||||
| struct TypeEqual { | struct TypeEqual { | ||||
| bool operator()(TypePtr const& t1, TypePtr const& t2) const; | |||||
| bool operator()(TypePtr const &t1, TypePtr const &t2) const; | |||||
| }; | }; | ||||
| struct TypeListEqual { | struct TypeListEqual { | ||||
| bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; | |||||
| bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; | |||||
| }; | }; | ||||
| extern const TypePtr kTypeExternal; | extern const TypePtr kTypeExternal; | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "pybind_api/export_flags.h" | #include "pybind_api/export_flags.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) { | |||||
| static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| bool begin = true; | bool begin = true; | ||||
| int cnt = 0; | int cnt = 0; | ||||
| @@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { | |||||
| } else { | } else { | ||||
| TypePtrList elements; | TypePtrList elements; | ||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | ||||
| [](const TypePtr& ele) { return ele->DeepCopy(); }); | |||||
| [](const TypePtr &ele) { return ele->DeepCopy(); }); | |||||
| auto copy = std::make_shared<List>(elements); | auto copy = std::make_shared<List>(elements); | ||||
| return copy; | return copy; | ||||
| } | } | ||||
| @@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { | |||||
| return elements_[dim]; | return elements_[dim]; | ||||
| } | } | ||||
| bool List::operator==(const Type& other) const { | |||||
| bool List::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| const List& other_list = static_cast<const List&>(other); | |||||
| const List &other_list = static_cast<const List &>(other); | |||||
| if (elements_.size() != other_list.elements_.size()) { | if (elements_.size() != other_list.elements_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { | |||||
| return true; | return true; | ||||
| } | } | ||||
| Class::Class(const Named& tag, const ClassAttrVector& attributes, | |||||
| const std::unordered_map<std::string, ValuePtr>& methods) | |||||
| Class::Class(const Named &tag, const ClassAttrVector &attributes, | |||||
| const std::unordered_map<std::string, ValuePtr> &methods) | |||||
| : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} | : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} | ||||
| std::string List::ToString() const { | std::string List::ToString() const { | ||||
| @@ -122,7 +122,7 @@ std::string List::DumpText() const { | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| bool Class::operator==(const Type& other) const { | |||||
| bool Class::operator==(const Type &other) const { | |||||
| // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. | // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. | ||||
| return &other == this; | return &other == this; | ||||
| } | } | ||||
| @@ -143,7 +143,7 @@ std::string Class::ToString() const { | |||||
| } else { | } else { | ||||
| bool begin = true; | bool begin = true; | ||||
| buffer << "cls." << tag_ << "["; | buffer << "cls." << tag_ << "["; | ||||
| for (auto& attr : attributes_) { | |||||
| for (auto &attr : attributes_) { | |||||
| if (!begin) { | if (!begin) { | ||||
| buffer << ", "; | buffer << ", "; | ||||
| } else { | } else { | ||||
| @@ -163,7 +163,7 @@ std::string Class::DumpText() const { | |||||
| } else { | } else { | ||||
| bool begin = true; | bool begin = true; | ||||
| buffer << "Cls." << tag_ << "["; | buffer << "Cls." << tag_ << "["; | ||||
| for (auto& attr : attributes_) { | |||||
| for (auto &attr : attributes_) { | |||||
| if (!begin) { | if (!begin) { | ||||
| buffer << ", "; | buffer << ", "; | ||||
| } else { | } else { | ||||
| @@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { | |||||
| } else { | } else { | ||||
| TypePtrList elements; | TypePtrList elements; | ||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), | ||||
| [](const TypePtr& ele) { return ele->DeepCopy(); }); | |||||
| [](const TypePtr &ele) { return ele->DeepCopy(); }); | |||||
| auto copy = std::make_shared<Tuple>(elements); | auto copy = std::make_shared<Tuple>(elements); | ||||
| return copy; | return copy; | ||||
| } | } | ||||
| } | } | ||||
| bool Tuple::operator==(const Type& other) const { | |||||
| bool Tuple::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_tuple = static_cast<const Tuple&>(other); | |||||
| auto other_tuple = static_cast<const Tuple &>(other); | |||||
| if (elements_.size() != other_tuple.elements_.size()) { | if (elements_.size() != other_tuple.elements_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { | |||||
| std::vector<std::pair<std::string, TypePtr>> kv; | std::vector<std::pair<std::string, TypePtr>> kv; | ||||
| (void)std::transform( | (void)std::transform( | ||||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | key_values_.begin(), key_values_.end(), std::back_inserter(kv), | ||||
| [](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); | |||||
| [](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); }); | |||||
| return std::make_shared<Dictionary>(kv); | return std::make_shared<Dictionary>(kv); | ||||
| } | } | ||||
| } | } | ||||
| @@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| std::vector<std::string> keys; | std::vector<std::string> keys; | ||||
| std::vector<TypePtr> values; | std::vector<TypePtr> values; | ||||
| for (const auto& kv : key_values_) { | |||||
| for (const auto &kv : key_values_) { | |||||
| keys.push_back(kv.first); | keys.push_back(kv.first); | ||||
| values.push_back(kv.second); | values.push_back(kv.second); | ||||
| } | } | ||||
| @@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { | |||||
| std::string Dictionary::DumpText() const { return ToString(); } | std::string Dictionary::DumpText() const { return ToString(); } | ||||
| bool Dictionary::operator==(const mindspore::Type& other) const { | |||||
| bool Dictionary::operator==(const mindspore::Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| const auto& other_dict = static_cast<const Dictionary&>(other); | |||||
| const auto &other_dict = static_cast<const Dictionary &>(other); | |||||
| if (key_values_.size() != other_dict.key_values_.size()) { | if (key_values_.size() != other_dict.key_values_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -40,10 +40,10 @@ namespace mindspore { | |||||
| class List : public Object { | class List : public Object { | ||||
| public: | public: | ||||
| List() : Object(kObjectTypeList) {} | List() : Object(kObjectTypeList) {} | ||||
| List(const std::initializer_list<TypePtr>& objs) | |||||
| List(const std::initializer_list<TypePtr> &objs) | |||||
| : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} | : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} | ||||
| // Shadow copy; | // Shadow copy; | ||||
| explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} | |||||
| explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} | |||||
| ~List() override {} | ~List() override {} | ||||
| MS_DECLARE_PARENT(List, Object) | MS_DECLARE_PARENT(List, Object) | ||||
| @@ -51,7 +51,7 @@ class List : public Object { | |||||
| TypeId generic_type_id() const override { return kObjectTypeList; } | TypeId generic_type_id() const override { return kObjectTypeList; } | ||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| std::size_t size() const { return elements_.size(); } | std::size_t size() const { return elements_.size(); } | ||||
| TypePtrList elements() const { return elements_; } | TypePtrList elements() const { return elements_; } | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| @@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>; | |||||
| class Class : public Object { | class Class : public Object { | ||||
| public: | public: | ||||
| Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} | Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} | ||||
| Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods); | |||||
| Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods); | |||||
| ~Class() override {} | ~Class() override {} | ||||
| MS_DECLARE_PARENT(Class, Object) | MS_DECLARE_PARENT(Class, Object) | ||||
| TypeId generic_type_id() const override { return kObjectTypeClass; } | TypeId generic_type_id() const override { return kObjectTypeClass; } | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; } | |||||
| void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; } | |||||
| Named tag() { return tag_; } | Named tag() { return tag_; } | ||||
| std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; } | std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; } | ||||
| std::unordered_map<std::string, ValuePtr> methods() { return methods_; } | std::unordered_map<std::string, ValuePtr> methods() { return methods_; } | ||||
| ClassAttrVector& GetAttributes() { return attributes_; } | |||||
| ClassAttrVector &GetAttributes() { return attributes_; } | |||||
| ClassAttrVector attributes_; | ClassAttrVector attributes_; | ||||
| @@ -99,11 +99,11 @@ class Tuple : public Object { | |||||
| public: | public: | ||||
| Tuple() : Object(kObjectTypeTuple) {} | Tuple() : Object(kObjectTypeTuple) {} | ||||
| // usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)}; | // usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)}; | ||||
| Tuple(const std::initializer_list<TypePtr>& objs) | |||||
| Tuple(const std::initializer_list<TypePtr> &objs) | |||||
| : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | ||||
| // Shadow copy | // Shadow copy | ||||
| explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | |||||
| explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} | |||||
| ~Tuple() override {} | ~Tuple() override {} | ||||
| MS_DECLARE_PARENT(Tuple, Object) | MS_DECLARE_PARENT(Tuple, Object) | ||||
| @@ -115,7 +115,7 @@ class Tuple : public Object { | |||||
| std::string ToReprString() const override { return "tuple_"; } | std::string ToReprString() const override { return "tuple_"; } | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| const TypePtr operator[](size_t dim) const; | const TypePtr operator[](size_t dim) const; | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| TypePtrList elements() const { return elements_; } | TypePtrList elements() const { return elements_; } | ||||
| std::size_t size() const { return elements_.size(); } | std::size_t size() const { return elements_.size(); } | ||||
| @@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>; | |||||
| class Dictionary : public Object { | class Dictionary : public Object { | ||||
| public: | public: | ||||
| Dictionary() : Object(kObjectTypeDictionary) {} | Dictionary() : Object(kObjectTypeDictionary) {} | ||||
| explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values) | |||||
| explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values) | |||||
| : Object(kObjectTypeDictionary, false), key_values_(key_values) {} | : Object(kObjectTypeDictionary, false), key_values_(key_values) {} | ||||
| ~Dictionary() override {} | ~Dictionary() override {} | ||||
| @@ -136,7 +136,7 @@ class Dictionary : public Object { | |||||
| TypeId generic_type_id() const override { return kObjectTypeDictionary; } | TypeId generic_type_id() const override { return kObjectTypeDictionary; } | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| TypePtr DeepCopy() const override; | TypePtr DeepCopy() const override; | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| @@ -24,11 +24,11 @@ | |||||
| #include "pybind_api/export_flags.h" | #include "pybind_api/export_flags.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| bool Number::operator==(const Type& other) const { | |||||
| bool Number::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | if (!IsSameObjectType(*this, other)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto other_number = static_cast<const Number&>(other); | |||||
| auto other_number = static_cast<const Number &>(other); | |||||
| return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); | return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); | ||||
| } | } | ||||
| @@ -49,12 +49,12 @@ class Number : public Object { | |||||
| TypeId type_id() const override { return number_type_; } | TypeId type_id() const override { return number_type_; } | ||||
| TypeId generic_type_id() const override { return kObjectTypeNumber; } | TypeId generic_type_id() const override { return kObjectTypeNumber; } | ||||
| bool operator==(const Type& other) const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| TypePtr DeepCopy() const override { return std::make_shared<Number>(); } | TypePtr DeepCopy() const override { return std::make_shared<Number>(); } | ||||
| std::string ToString() const override { return "Number"; } | std::string ToString() const override { return "Number"; } | ||||
| std::string ToReprString() const override { return "number"; } | std::string ToReprString() const override { return "number"; } | ||||
| std::string DumpText() const override { return "Number"; } | std::string DumpText() const override { return "Number"; } | ||||
| std::string GetTypeName(const std::string& type_name) const { | |||||
| std::string GetTypeName(const std::string &type_name) const { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << type_name; | oss << type_name; | ||||
| if (nbits() != 0) { | if (nbits() != 0) { | ||||
| @@ -51,7 +51,7 @@ class RefKeyType : public Object { | |||||
| class RefType : public Object { | class RefType : public Object { | ||||
| public: | public: | ||||
| RefType() : Object(kObjectTypeRef) {} | RefType() : Object(kObjectTypeRef) {} | ||||
| RefType(const TypePtr& subtype, const TypePtr& subtype_origin) | |||||
| RefType(const TypePtr &subtype, const TypePtr &subtype_origin) | |||||
| : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} | : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} | ||||
| ~RefType() override {} | ~RefType() override {} | ||||
| MS_DECLARE_PARENT(RefType, Object) | MS_DECLARE_PARENT(RefType, Object) | ||||
| @@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { | |||||
| } | } | ||||
| } | } | ||||
| const char* MetaIdLabel(const TypeId& v) { | |||||
| const char *MetaIdLabel(const TypeId &v) { | |||||
| switch (v) { | switch (v) { | ||||
| case kTypeUnknown: | case kTypeUnknown: | ||||
| return "kTypeUnknown"; | return "kTypeUnknown"; | ||||
| @@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { | |||||
| } | } | ||||
| } | } | ||||
| const char* ObjectIdLabel(const TypeId& v) { | |||||
| const char *ObjectIdLabel(const TypeId &v) { | |||||
| switch (v) { | switch (v) { | ||||
| case kObjectTypeNumber: | case kObjectTypeNumber: | ||||
| return "kObjectTypeNumber"; | return "kObjectTypeNumber"; | ||||
| @@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) { | |||||
| } | } | ||||
| } | } | ||||
| const char* NumberIdLabel(const TypeId& v) { | |||||
| const char *NumberIdLabel(const TypeId &v) { | |||||
| switch (v) { | switch (v) { | ||||
| case kNumberTypeBool: | case kNumberTypeBool: | ||||
| return "kNumberTypeBool"; | return "kNumberTypeBool"; | ||||
| @@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { | |||||
| } | } | ||||
| } | } | ||||
| const char* TypeIdLabel(const TypeId& v) { | |||||
| const char *TypeIdLabel(const TypeId &v) { | |||||
| if (v < kMetaTypeEnd) { | if (v < kMetaTypeEnd) { | ||||
| return MetaIdLabel(v); | return MetaIdLabel(v); | ||||
| } else { | } else { | ||||
| @@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { | |||||
| } | } | ||||
| } | } | ||||
| bool IsSameObjectType(const Type& lhs, const Type& rhs) { | |||||
| bool IsSameObjectType(const Type &lhs, const Type &rhs) { | |||||
| if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { | if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| return lhs.object_type() == rhs.object_type(); | return lhs.object_type() == rhs.object_type(); | ||||
| } | } | ||||
| size_t GetTypeByte(const TypePtr& type_ptr) { | |||||
| size_t GetTypeByte(const TypePtr &type_ptr) { | |||||
| if (type_ptr && type_ptr->isa<Number>()) { | if (type_ptr && type_ptr->isa<Number>()) { | ||||
| auto number = dyn_cast<Number>(type_ptr); | auto number = dyn_cast<Number>(type_ptr); | ||||
| if (!number) { | if (!number) { | ||||
| @@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { | |||||
| } | } | ||||
| } | } | ||||
| bool Type::operator==(const Value& other) const { | |||||
| bool Type::operator==(const Value &other) const { | |||||
| if (other.isa<Type>()) { | if (other.isa<Type>()) { | ||||
| auto other_type = static_cast<const Type*>(&other); | |||||
| auto other_type = static_cast<const Type *>(&other); | |||||
| return *this == *other_type; | return *this == *other_type; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| @@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { | |||||
| return ptr; | return ptr; | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const Type& type) { | |||||
| std::ostream &operator<<(std::ostream &os, const Type &type) { | |||||
| os << type.ToString(); | os << type.ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const TypePtr type) { | |||||
| std::ostream &operator<<(std::ostream &os, const TypePtr type) { | |||||
| os << type->ToString(); | os << type->ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| @@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const Object& obj) { | |||||
| std::ostream &operator<<(std::ostream &os, const Object &obj) { | |||||
| os << obj.ToString(); | os << obj.ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) { | |||||
| std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) { | |||||
| os << obj->ToString(); | os << obj->ToString(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { | |||||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types) { | |||||
| os << "["; | os << "["; | ||||
| for (size_t i = 0; i < types.size(); ++i) { | for (size_t i = 0; i < types.size(); ++i) { | ||||
| if (i > 0) { | if (i > 0) { | ||||
| @@ -95,10 +95,10 @@ enum TypeId : int { | |||||
| TypeId IntBitsToTypeId(const int nbits); | TypeId IntBitsToTypeId(const int nbits); | ||||
| TypeId UIntBitsToTypeId(const int nbits); | TypeId UIntBitsToTypeId(const int nbits); | ||||
| TypeId FloatBitsToTypeId(const int nbits); | TypeId FloatBitsToTypeId(const int nbits); | ||||
| const char* TypeIdLabel(const TypeId& v); | |||||
| const char *TypeIdLabel(const TypeId &v); | |||||
| TypeId NormalizeTypeId(const TypeId type_id); | TypeId NormalizeTypeId(const TypeId type_id); | ||||
| bool IsSameObjectType(const Type& lhs, const Type& rhs); | |||||
| size_t GetTypeByte(const TypePtr& type_ptr); | |||||
| bool IsSameObjectType(const Type &lhs, const Type &rhs); | |||||
| size_t GetTypeByte(const TypePtr &type_ptr); | |||||
| // Base class for all types | // Base class for all types | ||||
| // forward declaration. | // forward declaration. | ||||
| @@ -110,14 +110,14 @@ class Type : public Value { | |||||
| ~Type() override = default; | ~Type() override = default; | ||||
| MS_DECLARE_PARENT(Type, Value) | MS_DECLARE_PARENT(Type, Value) | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Value &other) const override; | |||||
| TypeId meta_type() const { return meta_type_; } | TypeId meta_type() const { return meta_type_; } | ||||
| virtual TypeId type_id() const { return meta_type_; } | virtual TypeId type_id() const { return meta_type_; } | ||||
| virtual TypeId generic_type_id() const { return kMetaTypeType; } | virtual TypeId generic_type_id() const { return kMetaTypeType; } | ||||
| virtual bool operator!=(const Type& other) const { return !(*this == other); } | |||||
| virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } | |||||
| virtual bool operator!=(const Type &other) const { return !(*this == other); } | |||||
| virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } | |||||
| virtual bool equal(const TypePtr other) const { return *this == *other; } | virtual bool equal(const TypePtr other) const { return *this == *other; } | ||||
| virtual TypeId object_type() const { return kTypeUnknown; } | virtual TypeId object_type() const { return kTypeUnknown; } | ||||
| @@ -134,8 +134,8 @@ class Type : public Value { | |||||
| bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } | bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } | ||||
| bool IsGeneric() const { return is_generic_; } | bool IsGeneric() const { return is_generic_; } | ||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| friend std::ostream& operator<<(std::ostream& os, const Type& type); | |||||
| friend std::ostream& operator<<(std::ostream& os, const TypePtr type); | |||||
| friend std::ostream &operator<<(std::ostream &os, const Type &type); | |||||
| friend std::ostream &operator<<(std::ostream &os, const TypePtr type); | |||||
| const bool parse_info_ = true; | const bool parse_info_ = true; | ||||
| @@ -163,14 +163,14 @@ class Object : public Type { | |||||
| bool equal(const TypePtr other) const override; | bool equal(const TypePtr other) const override; | ||||
| std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } | std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } | ||||
| friend std::ostream& operator<<(std::ostream& os, const Object& obj); | |||||
| friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj); | |||||
| friend std::ostream &operator<<(std::ostream &os, const Object &obj); | |||||
| friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj); | |||||
| private: | private: | ||||
| const TypeId object_type_; | const TypeId object_type_; | ||||
| }; | }; | ||||
| std::ostream& operator<<(std::ostream& os, const TypePtrList& types); | |||||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ | #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ | ||||
| @@ -61,7 +61,7 @@ FuncGraph::FuncGraph() | |||||
| AbstractFunctionPtr FuncGraph::abstract() { | AbstractFunctionPtr FuncGraph::abstract() { | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| for (auto& p : parameters_) { | |||||
| for (auto &p : parameters_) { | |||||
| MS_EXCEPTION_IF_NULL(p); | MS_EXCEPTION_IF_NULL(p); | ||||
| if (p->abstract() == nullptr) { | if (p->abstract() == nullptr) { | ||||
| MS_LOG(ERROR) << "Error!!"; | MS_LOG(ERROR) << "Error!!"; | ||||
| @@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() { | |||||
| return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract()); | return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract()); | ||||
| } | } | ||||
| abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { | |||||
| abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { | |||||
| AnalysisContextPtr temp_context = context; | AnalysisContextPtr temp_context = context; | ||||
| if (temp_context == nullptr) { | if (temp_context == nullptr) { | ||||
| temp_context = abstract::AnalysisContext::DummyContext(); | temp_context = abstract::AnalysisContext::DummyContext(); | ||||
| @@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const { | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { | |||||
| void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { | |||||
| if (force_new_ret || return_ == nullptr) { | if (force_new_ret || return_ == nullptr) { | ||||
| std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value}); | std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value}); | ||||
| FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | ||||
| @@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() { | |||||
| return p; | return p; | ||||
| } | } | ||||
| void FuncGraph::add_parameter(const ParameterPtr& p) { | |||||
| void FuncGraph::add_parameter(const ParameterPtr &p) { | |||||
| if (manager_.lock()) { | if (manager_.lock()) { | ||||
| std::vector<AnfNodePtr> new_params = parameters_; | std::vector<AnfNodePtr> new_params = parameters_; | ||||
| new_params.push_back(p); | new_params.push_back(p); | ||||
| @@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) { | |||||
| } | } | ||||
| } | } | ||||
| ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { | |||||
| ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||||
| FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | ||||
| ParameterPtr p = std::make_shared<Parameter>(this_graph); | ParameterPtr p = std::make_shared<Parameter>(this_graph); | ||||
| p->set_name(name); | p->set_name(name); | ||||
| @@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { | |||||
| return p; | return p; | ||||
| } | } | ||||
| bool FuncGraph::has_flag(const std::string& flag) { | |||||
| bool FuncGraph::has_flag(const std::string &flag) { | |||||
| if (flags_.count(flag)) { | if (flags_.count(flag)) { | ||||
| return flags_[flag]; | return flags_[flag]; | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) { | |||||
| CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||||
| CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); | ||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| order_.push_back(cnode); | order_.push_back(cnode); | ||||
| @@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr>& inputs) { | |||||
| return cnode; | return cnode; | ||||
| } | } | ||||
| CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, const ScopePtr& scope) { | |||||
| CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope) { | |||||
| CNodePtr app = NewCNode(inputs); | CNodePtr app = NewCNode(inputs); | ||||
| app->set_scope(scope); | app->set_scope(scope); | ||||
| return app; | return app; | ||||
| @@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector<AnfNodePtr>& inputs, con | |||||
| void FuncGraph::DumpCNodeList() { | void FuncGraph::DumpCNodeList() { | ||||
| MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; | MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; | ||||
| for (const auto& cnode : order_) { | |||||
| for (const auto &cnode : order_) { | |||||
| MS_LOG(INFO) << cnode->DebugString(); | MS_LOG(INFO) << cnode->DebugString(); | ||||
| } | } | ||||
| } | } | ||||
| std::string FuncGraph::ToString() const { | std::string FuncGraph::ToString() const { | ||||
| return mindspore::label_manage::Label(const_cast<FuncGraph*>(this)->shared_from_base<FuncGraph>()->debug_info()); | |||||
| return mindspore::label_manage::Label(const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info()); | |||||
| } | } | ||||
| GraphDebugInfoPtr FuncGraph::debug_info() { | GraphDebugInfoPtr FuncGraph::debug_info() { | ||||
| @@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() { | |||||
| return this->debug_info_; | return this->debug_info_; | ||||
| } | } | ||||
| const AnfNodeSet& FuncGraph::nodes() { | |||||
| const AnfNodeSet &FuncGraph::nodes() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& nodes = mng->nodes(); | |||||
| auto &nodes = mng->nodes(); | |||||
| return nodes[shared_from_base<FuncGraph>()]; | return nodes[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| const AnfNodeCounterMap& FuncGraph::value_nodes() { | |||||
| const AnfNodeCounterMap &FuncGraph::value_nodes() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& cts = mng->valuenodes(); | |||||
| auto &cts = mng->valuenodes(); | |||||
| return cts[shared_from_base<FuncGraph>()]; | return cts[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| const AnfNodeCounterMap& FuncGraph::free_variables_direct() { | |||||
| const AnfNodeCounterMap &FuncGraph::free_variables_direct() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& fv_direct = mng->free_variables_direct(); | |||||
| auto &fv_direct = mng->free_variables_direct(); | |||||
| return fv_direct[shared_from_base<FuncGraph>()]; | return fv_direct[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| const BaseRefCounterMap& FuncGraph::free_variables_total() { | |||||
| const BaseRefCounterMap &FuncGraph::free_variables_total() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& fv_total = mng->free_variables_total(); | |||||
| auto &fv_total = mng->free_variables_total(); | |||||
| return fv_total[shared_from_base<FuncGraph>()]; | return fv_total[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { | std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { | ||||
| std::vector<AnfNodePtr> nodes; | std::vector<AnfNodePtr> nodes; | ||||
| const auto& fv_total = this->free_variables_total(); | |||||
| for (auto& p : fv_total) { | |||||
| const auto &fv_total = this->free_variables_total(); | |||||
| for (auto &p : fv_total) { | |||||
| auto key = p.first; | auto key = p.first; | ||||
| if (utils::isa<AnfNodePtr>(key)) { | if (utils::isa<AnfNodePtr>(key)) { | ||||
| nodes.push_back(utils::cast<AnfNodePtr>(key)); | nodes.push_back(utils::cast<AnfNodePtr>(key)); | ||||
| @@ -238,8 +238,8 @@ std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() { | |||||
| std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | ||||
| std::vector<FuncGraphPtr> func_graphs; | std::vector<FuncGraphPtr> func_graphs; | ||||
| const auto& fv_total = this->free_variables_total(); | |||||
| for (auto& p : fv_total) { | |||||
| const auto &fv_total = this->free_variables_total(); | |||||
| for (auto &p : fv_total) { | |||||
| auto key = p.first; | auto key = p.first; | ||||
| if (utils::isa<FuncGraphPtr>(key)) { | if (utils::isa<FuncGraphPtr>(key)) { | ||||
| func_graphs.push_back(utils::cast<FuncGraphPtr>(key)); | func_graphs.push_back(utils::cast<FuncGraphPtr>(key)); | ||||
| @@ -249,31 +249,31 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() { | |||||
| return func_graphs; | return func_graphs; | ||||
| } | } | ||||
| const FuncGraphCounterMap& FuncGraph::func_graphs_used() { | |||||
| const FuncGraphCounterMap &FuncGraph::func_graphs_used() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& used = mng->func_graphs_used(); | |||||
| auto &used = mng->func_graphs_used(); | |||||
| return used[shared_from_base<FuncGraph>()]; | return used[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| const FuncGraphSet& FuncGraph::func_graphs_used_total() { | |||||
| const FuncGraphSet &FuncGraph::func_graphs_used_total() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& used = mng->func_graphs_used_total(shared_from_base<FuncGraph>()); | |||||
| auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>()); | |||||
| return used; | return used; | ||||
| } | } | ||||
| const FuncGraphCounterMap& FuncGraph::func_graph_users() { | |||||
| const FuncGraphCounterMap &FuncGraph::func_graph_users() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& users = mng->func_graph_users(); | |||||
| auto &users = mng->func_graph_users(); | |||||
| return users[shared_from_base<FuncGraph>()]; | return users[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { | |||||
| const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| auto& users = mng->func_graph_user_cnodes(); | |||||
| auto &users = mng->func_graph_user_cnodes(); | |||||
| return users[shared_from_base<FuncGraph>()]; | return users[shared_from_base<FuncGraph>()]; | ||||
| } | } | ||||
| @@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() { | |||||
| return mng->parent(shared_from_base<FuncGraph>()); | return mng->parent(shared_from_base<FuncGraph>()); | ||||
| } | } | ||||
| const FuncGraphSet& FuncGraph::children() { | |||||
| const FuncGraphSet &FuncGraph::children() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| return mng->children(shared_from_base<FuncGraph>()); | return mng->children(shared_from_base<FuncGraph>()); | ||||
| } | } | ||||
| const FuncGraphSet& FuncGraph::scope() { | |||||
| const FuncGraphSet &FuncGraph::scope() { | |||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| return mng->scopes(shared_from_base<FuncGraph>()); | return mng->scopes(shared_from_base<FuncGraph>()); | ||||
| @@ -312,9 +312,9 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() { | |||||
| return mng->recursive_graphs(shared_from_base<FuncGraph>()); | return mng->recursive_graphs(shared_from_base<FuncGraph>()); | ||||
| } | } | ||||
| void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); } | |||||
| void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); } | |||||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { | |||||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | |||||
| auto itr = this->parameter_default_value_.find(name); | auto itr = this->parameter_default_value_.find(name); | ||||
| if (itr == parameter_default_value_.end()) { | if (itr == parameter_default_value_.end()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { | |||||
| } | } | ||||
| // set the default values | // set the default values | ||||
| void FuncGraph::SetDefaultValues(const std::vector<std::string>& name_list, const std::vector<AnfNodePtr>& value_list) { | |||||
| void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) { | |||||
| auto all_is_null = std::all_of(value_list.begin(), value_list.end(), | auto all_is_null = std::all_of(value_list.begin(), value_list.end(), | ||||
| [](const AnfNodePtr& node) { return IsValueNode<NullObj>(node); }); | |||||
| [](const AnfNodePtr &node) { return IsValueNode<NullObj>(node); }); | |||||
| if (value_list.empty()) { | if (value_list.empty()) { | ||||
| all_is_null = true; | all_is_null = true; | ||||
| } | } | ||||
| @@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } | |||||
| size_t FuncGraph::GetDefaultValueCount() { | size_t FuncGraph::GetDefaultValueCount() { | ||||
| int null_count = | int null_count = | ||||
| std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), | std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), | ||||
| [](const std::pair<std::string, AnfNodePtr>& pair) { return IsValueNode<NullObj>(pair.second); }); | |||||
| [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<NullObj>(pair.second); }); | |||||
| return parameter_default_value_.size() - IntToSize(null_count); | return parameter_default_value_.size() - IntToSize(null_count); | ||||
| } | } | ||||
| @@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const { | |||||
| return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); | return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); | ||||
| } | } | ||||
| AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { | |||||
| AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { | |||||
| for (size_t i = 0; i < parameters_.size(); ++i) { | for (size_t i = 0; i < parameters_.size(); ++i) { | ||||
| MS_EXCEPTION_IF_NULL(parameters_[i]); | MS_EXCEPTION_IF_NULL(parameters_[i]); | ||||
| auto param_cast = parameters_[i]->cast<ParameterPtr>(); | auto param_cast = parameters_[i]->cast<ParameterPtr>(); | ||||
| @@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, | |||||
| std::vector<AnfNodePtr>* specialized_parameter_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, int variable_args_count, | |||||
| void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, | |||||
| std::vector<AnfNodePtr> *specialized_parameter_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count, | |||||
| int pos_args_input_count) { | int pos_args_input_count) { | ||||
| // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple | // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple | ||||
| if (specialized_graph->has_vararg()) { | if (specialized_graph->has_vararg()) { | ||||
| @@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||||
| std::vector<AnfNodePtr>* specialized_parameter_list, | |||||
| const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { | |||||
| void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, | |||||
| std::vector<AnfNodePtr> *specialized_parameter_list, | |||||
| const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) { | |||||
| std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | ||||
| std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; | ||||
| for (const auto& kwarg : kwarg_list) { | |||||
| for (const auto &kwarg : kwarg_list) { | |||||
| MS_EXCEPTION_IF_NULL(kwarg); | MS_EXCEPTION_IF_NULL(kwarg); | ||||
| std::string kw_param_name = kwarg->get_key(); | std::string kw_param_name = kwarg->get_key(); | ||||
| MS_EXCEPTION_IF_NULL(specialized_graph); | MS_EXCEPTION_IF_NULL(specialized_graph); | ||||
| @@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||||
| std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; | std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; | ||||
| MS_EXCEPTION_IF_NULL(specialized_parameter_list); | MS_EXCEPTION_IF_NULL(specialized_parameter_list); | ||||
| auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), | auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), | ||||
| [param_name](const AnfNodePtr& node) { | |||||
| [param_name](const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto param = node->cast<ParameterPtr>(); | auto param = node->cast<ParameterPtr>(); | ||||
| return param != nullptr && param->name() == param_name; | return param != nullptr && param->name() == param_name; | ||||
| @@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, | |||||
| GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); | GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); | ||||
| } | } | ||||
| void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes, | |||||
| const std::vector<AnfNodePtr>& kwarg_keys_tuple_nodes, | |||||
| const std::vector<AnfNodePtr>& kwarg_values_tuple_nodes) { | |||||
| void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, | |||||
| const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes, | |||||
| const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes) { | |||||
| if (has_kwarg()) { | if (has_kwarg()) { | ||||
| MS_EXCEPTION_IF_NULL(specialized_graph); | MS_EXCEPTION_IF_NULL(specialized_graph); | ||||
| TraceManager::DebugTrace( | TraceManager::DebugTrace( | ||||
| @@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, | |||||
| } | } | ||||
| } | } | ||||
| bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) { | |||||
| bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) { | |||||
| // if the function does not have any vararg/kwarg/kwonly/default value/kw args input | // if the function does not have any vararg/kwarg/kwonly/default value/kw args input | ||||
| // return the original graph | // return the original graph | ||||
| if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { | if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { | ||||
| @@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& | |||||
| return true; | return true; | ||||
| } | } | ||||
| void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, | |||||
| const std::vector<AnfNodePtr>& specialized_parameter_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) { | |||||
| void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, | |||||
| const std::vector<AnfNodePtr> &specialized_parameter_list, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) { | |||||
| MS_EXCEPTION_IF_NULL(specialized_graph); | MS_EXCEPTION_IF_NULL(specialized_graph); | ||||
| for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { | for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { | ||||
| auto param_node = specialized_graph->parameters()[i]; | auto param_node = specialized_graph->parameters()[i]; | ||||
| @@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; | std::vector<abstract::AbstractKeywordArgPtr> kwarg_list; | ||||
| size_t arguments_count = args_spec_list.size(); | size_t arguments_count = args_spec_list.size(); | ||||
| for (const auto& arg : args_spec_list) { | |||||
| for (const auto &arg : args_spec_list) { | |||||
| // if it is a keyword argument | // if it is a keyword argument | ||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| if (arg->isa<abstract::AbstractKeywordArg>()) { | if (arg->isa<abstract::AbstractKeywordArg>()) { | ||||
| @@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) | |||||
| MS_EXCEPTION_IF_NULL(specialized_graph); | MS_EXCEPTION_IF_NULL(specialized_graph); | ||||
| auto params = specialized_graph->parameters(); | auto params = specialized_graph->parameters(); | ||||
| (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), | (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), | ||||
| std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); | |||||
| std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); | |||||
| std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); | std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false); | ||||
| auto tr = manager->Transact(); | auto tr = manager->Transact(); | ||||
| for (auto& node_pair : repl_nodes) { | |||||
| for (auto &node_pair : repl_nodes) { | |||||
| MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" | MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" | ||||
| << node_pair.second->DebugString(); | << node_pair.second->DebugString(); | ||||
| (void)tr.Replace(node_pair.first, node_pair.second); | (void)tr.Replace(node_pair.first, node_pair.second); | ||||
| @@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) | |||||
| return specialized_graph; | return specialized_graph; | ||||
| } | } | ||||
| void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } | |||||
| void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } | |||||
| std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | ||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| @@ -651,7 +651,7 @@ std::list<CNodePtr> FuncGraph::GetOrderedCnodes() { | |||||
| std::list<CNodePtr> cnodes; | std::list<CNodePtr> cnodes; | ||||
| auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); | auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); | ||||
| for (const auto& node : nodes) { | |||||
| for (const auto &node : nodes) { | |||||
| auto cnode = dyn_cast<CNode>(node); | auto cnode = dyn_cast<CNode>(node); | ||||
| if (cnode) { | if (cnode) { | ||||
| cnodes.push_back(cnode); | cnodes.push_back(cnode); | ||||
| @@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() { | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { | |||||
| void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { | |||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa<CNode>()) { | ||||
| order_.remove(n->cast<CNodePtr>()); | order_.remove(n->cast<CNodePtr>()); | ||||
| MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; | MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; | ||||
| @@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() { | |||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| MS_LOG(DEBUG) << "Check graph " << ToString(); | MS_LOG(DEBUG) << "Check graph " << ToString(); | ||||
| for (auto it = order_.begin(); it != order_.end(); (void)it++) { | for (auto it = order_.begin(); it != order_.end(); (void)it++) { | ||||
| for (const auto& input_node : (*it)->inputs()) { | |||||
| for (const auto &input_node : (*it)->inputs()) { | |||||
| if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) { | if (input_node && input_node->isa<CNode>() && input_node->func_graph() == shared_from_base<FuncGraph>()) { | ||||
| // Need to reorder the wrong order node. | // Need to reorder the wrong order node. | ||||
| auto found = std::find(order_.begin(), it, input_node); | auto found = std::find(order_.begin(), it, input_node); | ||||
| @@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() { | |||||
| } | } | ||||
| auto mng = manager_.lock(); | auto mng = manager_.lock(); | ||||
| if (mng != nullptr) { | if (mng != nullptr) { | ||||
| const auto& nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||||
| const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()]; | |||||
| if (nodes.size() != (order_.size() + parameters_.size())) { | if (nodes.size() != (order_.size() + parameters_.size())) { | ||||
| DumpCNodeList(); | DumpCNodeList(); | ||||
| MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " | ||||
| @@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() { | |||||
| const char kPrimHasEffect[] = "_side_effect_flag"; | const char kPrimHasEffect[] = "_side_effect_flag"; | ||||
| bool FuncGraph::HasEffect(const CNodePtr& cnode) { | |||||
| bool FuncGraph::HasEffect(const CNodePtr &cnode) { | |||||
| auto prim = GetCNodePrimitive(cnode); | auto prim = GetCNodePrimitive(cnode); | ||||
| if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) { | if (prim != nullptr && prim->isa<prim::DoSignaturePrimitive>()) { | ||||
| auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>(); | auto do_sig = prim->cast<prim::DoSignaturePrimitivePtr>(); | ||||
| @@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& segment) { | |||||
| std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) { | |||||
| std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment); | std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment); | ||||
| for (const auto& node : segment) { | |||||
| for (const auto &node : segment) { | |||||
| if (roots->size() == 1) { | if (roots->size() == 1) { | ||||
| return roots; | return roots; | ||||
| } | } | ||||
| @@ -757,9 +757,9 @@ std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr>& seg | |||||
| return roots; | return roots; | ||||
| } | } | ||||
| std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr>& segment) { | |||||
| std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) { | |||||
| std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment); | std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment); | ||||
| for (const auto& node : segment) { | |||||
| for (const auto &node : segment) { | |||||
| if (nodes->size() == 1) { | if (nodes->size() == 1) { | ||||
| return nodes; | return nodes; | ||||
| } | } | ||||
| @@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||||
| if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { | ||||
| std::list<AnfNodePtr> depends_order; | std::list<AnfNodePtr> depends_order; | ||||
| std::vector<CNodePtr> segment; | std::vector<CNodePtr> segment; | ||||
| for (const auto& cnode : order_) { | |||||
| for (const auto &cnode : order_) { | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { | if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) { | |||||
| void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) { | |||||
| auto old_ret = output(); | auto old_ret = output(); | ||||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret}; | std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimDepend), old_ret}; | ||||
| (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); | (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); | ||||
| @@ -26,29 +26,29 @@ | |||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| namespace mindspore { | namespace mindspore { | ||||
| Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, | |||||
| bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) | |||||
| Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, | |||||
| bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) | |||||
| : clone_all_valuenodes_(clone_all_valuenodes), | : clone_all_valuenodes_(clone_all_valuenodes), | ||||
| clone_all_child_graphs_(clone_all_child_graphs), | clone_all_child_graphs_(clone_all_child_graphs), | ||||
| clone_all_used_graphs_(clone_all_used_graphs), | clone_all_used_graphs_(clone_all_used_graphs), | ||||
| relation_(relation), | relation_(relation), | ||||
| target_relation_(target_relation == nullptr ? relation : target_relation) { | target_relation_(target_relation == nullptr ? relation : target_relation) { | ||||
| for (auto& func_graph : func_graphs) { | |||||
| for (auto &func_graph : func_graphs) { | |||||
| AddClone(func_graph); | AddClone(func_graph); | ||||
| } | } | ||||
| scope_ = kDefaultScope; | scope_ = kDefaultScope; | ||||
| type_ = kBasic; | type_ = kBasic; | ||||
| } | } | ||||
| void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||||
| const AnfNodePtrList& params, CloneType type) { | |||||
| void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||||
| const AnfNodePtrList ¶ms, CloneType type) { | |||||
| if (func_graph != nullptr) { | if (func_graph != nullptr) { | ||||
| todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); | todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); | ||||
| type_ = type; | type_ = type; | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||||
| void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) { | if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) { | ||||
| return; | return; | ||||
| @@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||||
| } | } | ||||
| } | } | ||||
| void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { | |||||
| void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(target); | MS_EXCEPTION_IF_NULL(target); | ||||
| TraceManager::DebugTrace(node->debug_info(), relation_); | TraceManager::DebugTrace(node->debug_info(), relation_); | ||||
| @@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||||
| void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(target); | MS_EXCEPTION_IF_NULL(target); | ||||
| TraceManager::DebugTrace(node->debug_info(), relation_); | TraceManager::DebugTrace(node->debug_info(), relation_); | ||||
| @@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| void Cloner::CloneValueNode(const AnfNodePtr& node) { | |||||
| void Cloner::CloneValueNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| TraceManager::DebugTrace(node->debug_info(), relation_); | TraceManager::DebugTrace(node->debug_info(), relation_); | ||||
| ValueNodePtr new_const = NewValueNode(GetValueNode(node)); | ValueNodePtr new_const = NewValueNode(GetValueNode(node)); | ||||
| @@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) { | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { | |||||
| void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(target); | MS_EXCEPTION_IF_NULL(target); | ||||
| TraceManager::DebugTrace(node->debug_info(), relation_); | TraceManager::DebugTrace(node->debug_info(), relation_); | ||||
| @@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { | |||||
| void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| if (!clone_all_valuenodes_) { | if (!clone_all_valuenodes_) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto& value_nodes = manager_->valuenodes()[func_graph]; | |||||
| for (auto& value_node : value_nodes) { | |||||
| auto &value_nodes = manager_->valuenodes()[func_graph]; | |||||
| for (auto &value_node : value_nodes) { | |||||
| auto old_node = value_node.first; | auto old_node = value_node.first; | ||||
| MS_EXCEPTION_IF_NULL(old_node); | MS_EXCEPTION_IF_NULL(old_node); | ||||
| if (repl_node_.count(old_node) == 0) { | if (repl_node_.count(old_node) == 0) { | ||||
| @@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { | |||||
| } | } | ||||
| } | } | ||||
| void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { | |||||
| void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| if (!clone_all_child_graphs_) { | if (!clone_all_child_graphs_) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto& scopes = manager_->scopes(func_graph); | |||||
| for (auto& graph : scopes) { | |||||
| auto &scopes = manager_->scopes(func_graph); | |||||
| for (auto &graph : scopes) { | |||||
| if (graph != func_graph) { | if (graph != func_graph) { | ||||
| todo_.push_back({graph, nullptr, {}}); | todo_.push_back({graph, nullptr, {}}); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { | |||||
| void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| if (!clone_all_used_graphs_) { | if (!clone_all_used_graphs_) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto& used_graphs = manager_->func_graphs_used()[func_graph]; | |||||
| for (auto& used_graph : used_graphs) { | |||||
| auto &used_graphs = manager_->func_graphs_used()[func_graph]; | |||||
| for (auto &used_graph : used_graphs) { | |||||
| todo_.push_back({used_graph.first, nullptr, {}}); | todo_.push_back({used_graph.first, nullptr, {}}); | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||||
| void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| for (auto& item : func_graph->parameter_default_value()) { | |||||
| for (auto &item : func_graph->parameter_default_value()) { | |||||
| auto nodes = DeepLinkedGraphSearch(item.second); | auto nodes = DeepLinkedGraphSearch(item.second); | ||||
| for (auto& node : nodes) { | |||||
| for (auto &node : nodes) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| CloneNode(node, target_func_graph); | CloneNode(node, target_func_graph); | ||||
| @@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F | |||||
| } | } | ||||
| } | } | ||||
| void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||||
| void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| @@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func | |||||
| } | } | ||||
| target_func_graph->set_return(return_node); | target_func_graph->set_return(return_node); | ||||
| auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||||
| for (auto& value_node : value_nodes) { | |||||
| auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; | |||||
| for (auto &value_node : value_nodes) { | |||||
| CloneValueNode(value_node.first, target_func_graph); | CloneValueNode(value_node.first, target_func_graph); | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { | |||||
| void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto& old_params = func_graph->parameters(); | |||||
| auto &old_params = func_graph->parameters(); | |||||
| if (old_params.size() != params.size()) { | if (old_params.size() != params.size()) { | ||||
| MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; | MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; | ||||
| return; | return; | ||||
| @@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode | |||||
| } | } | ||||
| } | } | ||||
| void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { | |||||
| void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); | ||||
| @@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons | |||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| } | } | ||||
| void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||||
| void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| auto& params = func_graph->parameters(); | |||||
| for (auto& param : params) { | |||||
| auto ¶ms = func_graph->parameters(); | |||||
| for (auto ¶m : params) { | |||||
| CloneParameter(param, target_func_graph, true); | CloneParameter(param, target_func_graph, true); | ||||
| } | } | ||||
| repl_func_graph_[func_graph] = target_func_graph; | repl_func_graph_[func_graph] = target_func_graph; | ||||
| } | } | ||||
| void Cloner::GenParameters(const FuncGraphPtr& func_graph) { | |||||
| void Cloner::GenParameters(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto& free_vars = manager_->free_variables_total(); | |||||
| auto &free_vars = manager_->free_variables_total(); | |||||
| auto iter = free_vars.find(func_graph); | auto iter = free_vars.find(func_graph); | ||||
| if (iter == free_vars.end()) { | if (iter == free_vars.end()) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto& fv_map : iter->second) { | |||||
| auto& free_var = fv_map.first; | |||||
| for (auto &fv_map : iter->second) { | |||||
| auto &free_var = fv_map.first; | |||||
| if (utils::isa<AnfNodePtr>(free_var)) { | if (utils::isa<AnfNodePtr>(free_var)) { | ||||
| repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var))); | repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var))); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { | |||||
| void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { | |||||
| param->set_abstract(node->abstract()); | param->set_abstract(node->abstract()); | ||||
| if (node->isa<Parameter>()) { | if (node->isa<Parameter>()) { | ||||
| ParameterPtr old_param = dyn_cast<Parameter>(node); | ParameterPtr old_param = dyn_cast<Parameter>(node); | ||||
| @@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { | |||||
| } | } | ||||
| } | } | ||||
| ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { | |||||
| ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { | |||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info())); | ||||
| ParameterPtr param = std::make_shared<Parameter>(func_graph); | ParameterPtr param = std::make_shared<Parameter>(func_graph); | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| @@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP | |||||
| return param; | return param; | ||||
| } | } | ||||
| void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, | |||||
| AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { | |||||
| void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, | |||||
| AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { | |||||
| AnfNodePtrList parameters; | AnfNodePtrList parameters; | ||||
| std::unordered_set<AnfNodePtr> old_params; | std::unordered_set<AnfNodePtr> old_params; | ||||
| for (auto& param : func_graph->parameters()) { | |||||
| for (auto ¶m : func_graph->parameters()) { | |||||
| auto iter = repl_node_.find(param); | auto iter = repl_node_.find(param); | ||||
| if (iter != repl_node_.end()) { | if (iter != repl_node_.end()) { | ||||
| (void)old_params.insert(iter->second); | (void)old_params.insert(iter->second); | ||||
| @@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& | |||||
| } | } | ||||
| } | } | ||||
| AnfNodePtr new_param = nullptr; | AnfNodePtr new_param = nullptr; | ||||
| for (auto& param : params) { | |||||
| for (auto ¶m : params) { | |||||
| auto old_param = repl_node_[param]; | auto old_param = repl_node_[param]; | ||||
| if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) { | if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) { | ||||
| repl_node_[old_param] = old_param; | repl_node_[old_param] = old_param; | ||||
| @@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& | |||||
| func_graph->set_parameters(parameters); | func_graph->set_parameters(parameters); | ||||
| } | } | ||||
| void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtrList& params) { | |||||
| void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtrList ¶ms) { | |||||
| AnfNodePtr node = nullptr; | AnfNodePtr node = nullptr; | ||||
| auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; | |||||
| auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; | |||||
| auto iter = repl_func_graph.find(func_graph); | auto iter = repl_func_graph.find(func_graph); | ||||
| if (iter == repl_func_graph.end()) { | if (iter == repl_func_graph.end()) { | ||||
| node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); | node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); | ||||
| @@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& | |||||
| OrderParameters(func_graph, inputs); | OrderParameters(func_graph, inputs); | ||||
| } | } | ||||
| void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { | |||||
| void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { | |||||
| std::unordered_set<AnfNodePtr> old_params; | std::unordered_set<AnfNodePtr> old_params; | ||||
| for (auto& param : func_graph->parameters()) { | |||||
| for (auto ¶m : func_graph->parameters()) { | |||||
| (void)old_params.insert(repl_node_[param]); | (void)old_params.insert(repl_node_[param]); | ||||
| } | } | ||||
| std::unordered_set<AnfNodePtr> new_params; | std::unordered_set<AnfNodePtr> new_params; | ||||
| @@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis | |||||
| (void)new_params.insert(new_param); | (void)new_params.insert(new_param); | ||||
| } | } | ||||
| } | } | ||||
| for (auto& param : func_graph->parameters()) { | |||||
| for (auto ¶m : func_graph->parameters()) { | |||||
| if (new_params.find(param) == new_params.end()) { | if (new_params.find(param) == new_params.end()) { | ||||
| parameters.push_back(param); | parameters.push_back(param); | ||||
| } | } | ||||
| @@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis | |||||
| func_graph->set_parameters(parameters); | func_graph->set_parameters(parameters); | ||||
| } | } | ||||
| void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||||
| void Cloner::SetEdges(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| for (auto& node : func_graph->nodes()) { | |||||
| for (auto &node : func_graph->nodes()) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| auto& inputs = cnode->inputs(); | |||||
| auto &inputs = cnode->inputs(); | |||||
| for (size_t i = 0; i < inputs.size(); i++) { | for (size_t i = 0; i < inputs.size(); i++) { | ||||
| auto& input = inputs[i]; | |||||
| auto &input = inputs[i]; | |||||
| if (IsValueNode<FuncGraph>(input)) { | if (IsValueNode<FuncGraph>(input)) { | ||||
| auto graph = GetValueNode<FuncGraphPtr>(input); | auto graph = GetValueNode<FuncGraphPtr>(input); | ||||
| auto& repl_func_graph = repl_map_func_graph_[func_graph]; | |||||
| auto &repl_func_graph = repl_map_func_graph_[func_graph]; | |||||
| if (repl_func_graph.find(graph) != repl_func_graph.end()) { | if (repl_func_graph.find(graph) != repl_func_graph.end()) { | ||||
| transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); | transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); | ||||
| } | } | ||||
| } else { | } else { | ||||
| auto& repl_node = repl_map_node_[func_graph]; | |||||
| auto &repl_node = repl_map_node_[func_graph]; | |||||
| if (repl_node.find(input) != repl_node.end()) { | if (repl_node.find(input) != repl_node.end()) { | ||||
| transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); | transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); | ||||
| } | } | ||||
| @@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { | |||||
| } | } | ||||
| } | } | ||||
| void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtrList& params) { | |||||
| void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtrList ¶ms) { | |||||
| AnfNodePtrList lift_params; | AnfNodePtrList lift_params; | ||||
| AnfNodePtrList input_params; | AnfNodePtrList input_params; | ||||
| AddParameters(func_graph_user, params, &lift_params, &input_params); | AddParameters(func_graph_user, params, &lift_params, &input_params); | ||||
| @@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph | |||||
| if (lift_params.empty()) { | if (lift_params.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto& user : func_graph_user->func_graph_users()) { | |||||
| for (auto &user : func_graph_user->func_graph_users()) { | |||||
| LiftParameters(user.first, func_graph_user, lift_params); | LiftParameters(user.first, func_graph_user, lift_params); | ||||
| } | } | ||||
| } | } | ||||
| void Cloner::Lift() { | void Cloner::Lift() { | ||||
| for (auto& func_graph_params : repl_func_graph_params_) { | |||||
| auto& func_graph = func_graph_params.first; | |||||
| auto& params = func_graph_params.second; | |||||
| for (auto& user : func_graph->func_graph_users()) { | |||||
| for (auto &func_graph_params : repl_func_graph_params_) { | |||||
| auto &func_graph = func_graph_params.first; | |||||
| auto ¶ms = func_graph_params.second; | |||||
| for (auto &user : func_graph->func_graph_users()) { | |||||
| LiftParameters(user.first, func_graph, params); | LiftParameters(user.first, func_graph, params); | ||||
| } | } | ||||
| } | } | ||||
| @@ -404,18 +404,18 @@ void Cloner::Lift() { | |||||
| void Cloner::LiftParameters() { | void Cloner::LiftParameters() { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| transaction_ = manager_->Transact(); | transaction_ = manager_->Transact(); | ||||
| const FuncGraphSet& func_graphs = manager_->func_graphs(); | |||||
| for (auto& func_graph : func_graphs) { | |||||
| const FuncGraphSet &func_graphs = manager_->func_graphs(); | |||||
| for (auto &func_graph : func_graphs) { | |||||
| GenParameters(func_graph); | GenParameters(func_graph); | ||||
| } | } | ||||
| Lift(); | Lift(); | ||||
| for (auto& func_graph : func_graphs) { | |||||
| for (auto &func_graph : func_graphs) { | |||||
| SetEdges(func_graph); | SetEdges(func_graph); | ||||
| } | } | ||||
| transaction_.Commit(); | transaction_.Commit(); | ||||
| } | } | ||||
| bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { | |||||
| bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| // Make sure only inline once | // Make sure only inline once | ||||
| if (status_.count(func_graph) != 0) { | if (status_.count(func_graph) != 0) { | ||||
| @@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { | |||||
| void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| const AnfNodeSet& nodes = manager_->nodes()[func_graph]; | |||||
| for (auto& node : nodes) { | |||||
| const AnfNodeSet &nodes = manager_->nodes()[func_graph]; | |||||
| for (auto &node : nodes) { | |||||
| CloneNode(node, target_func_graph); | CloneNode(node, target_func_graph); | ||||
| } | } | ||||
| } | } | ||||
| @@ -449,7 +449,7 @@ void Cloner::Run() { | |||||
| // Basic and Inline Clone | // Basic and Inline Clone | ||||
| FuncGraphPtrList func_graphs; | FuncGraphPtrList func_graphs; | ||||
| (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), | (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), | ||||
| [](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); | |||||
| [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); | |||||
| manager_ = Manage(func_graphs, false); | manager_ = Manage(func_graphs, false); | ||||
| CloneNodes(); | CloneNodes(); | ||||
| LinkEdges(); | LinkEdges(); | ||||
| @@ -495,13 +495,13 @@ void Cloner::CloneNodes() { | |||||
| } | } | ||||
| void Cloner::LinkEdges() { | void Cloner::LinkEdges() { | ||||
| for (auto& node_pair : nodes_) { | |||||
| for (auto &node_pair : nodes_) { | |||||
| CNodePtr old_node = node_pair.first; | CNodePtr old_node = node_pair.first; | ||||
| CNodePtr new_node = node_pair.second; | CNodePtr new_node = node_pair.second; | ||||
| MS_EXCEPTION_IF_NULL(old_node); | MS_EXCEPTION_IF_NULL(old_node); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| for (auto& input : old_node->inputs()) { | |||||
| auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; | |||||
| for (auto &input : old_node->inputs()) { | |||||
| auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; | |||||
| new_node->add_input(new_input); | new_node->add_input(new_input); | ||||
| } | } | ||||
| } | } | ||||
| @@ -509,10 +509,10 @@ void Cloner::LinkEdges() { | |||||
| // For the graphs cloned, update its default value map to the cloned nodes | // For the graphs cloned, update its default value map to the cloned nodes | ||||
| void Cloner::SetDefaults() { | void Cloner::SetDefaults() { | ||||
| for (auto& item : graph_set_) { | |||||
| for (auto &item : graph_set_) { | |||||
| MS_EXCEPTION_IF_NULL(item); | MS_EXCEPTION_IF_NULL(item); | ||||
| if (repl_func_graph_.count(item) != 0) { | if (repl_func_graph_.count(item) != 0) { | ||||
| for (auto& param_def : item->parameter_default_value()) { | |||||
| for (auto ¶m_def : item->parameter_default_value()) { | |||||
| MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); | MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); | ||||
| if (repl_node_.count(param_def.second) != 0) { | if (repl_node_.count(param_def.second) != 0) { | ||||
| repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); | repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); | ||||
| @@ -524,7 +524,7 @@ void Cloner::SetDefaults() { | |||||
| } | } | ||||
| } | } | ||||
| AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { | |||||
| AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { | |||||
| MS_EXCEPTION_IF_NULL(root); | MS_EXCEPTION_IF_NULL(root); | ||||
| if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { | if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { | ||||
| MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; | MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; | ||||
| @@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { | |||||
| MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; | MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; | ||||
| } | } | ||||
| AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { | |||||
| AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double time = GetTime(); | double time = GetTime(); | ||||
| #endif | #endif | ||||
| @@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { | |||||
| return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); | return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); | ||||
| } | } | ||||
| FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { | |||||
| FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double time = GetTime(); | double time = GetTime(); | ||||
| #endif | #endif | ||||
| @@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { | |||||
| return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); | return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); | ||||
| } | } | ||||
| FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { | |||||
| FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr); | Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr); | ||||
| return cloner[func_graph]; | return cloner[func_graph]; | ||||
| } | } | ||||
| AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||||
| const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { | |||||
| AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||||
| const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(target_func_graph); | MS_EXCEPTION_IF_NULL(target_func_graph); | ||||
| Cloner cloner({}, false); | Cloner cloner({}, false); | ||||
| @@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe | |||||
| return cloner[func_graph->output()]; | return cloner[func_graph->output()]; | ||||
| } | } | ||||
| FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { | |||||
| FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| Cloner cloner({}, false); | Cloner cloner({}, false); | ||||
| cloner.AddClone(func_graph, nullptr, {}, kLifting); | cloner.AddClone(func_graph, nullptr, {}, kLifting); | ||||
| return cloner[func_graph]; | return cloner[func_graph]; | ||||
| } | } | ||||
| ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { | |||||
| ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| FuncGraphPtrList func_graphs = {func_graph}; | FuncGraphPtrList func_graphs = {func_graph}; | ||||
| ClonerPtr cloner = | ClonerPtr cloner = | ||||
| @@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r | |||||
| return cloner; | return cloner; | ||||
| } | } | ||||
| FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { | |||||
| FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| TraceManager::DebugTrace(func_graph->debug_info(), relation); | TraceManager::DebugTrace(func_graph->debug_info(), relation); | ||||
| auto new_func_graph = std::make_shared<FuncGraph>(); | auto new_func_graph = std::make_shared<FuncGraph>(); | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| auto& parameters = func_graph->parameters(); | |||||
| (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void { | |||||
| auto ¶meters = func_graph->parameters(); | |||||
| (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { | |||||
| MS_EXCEPTION_IF_NULL(param); | MS_EXCEPTION_IF_NULL(param); | ||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceCopy>(param->debug_info())); | ||||
| (void)new_func_graph->add_parameter(); | (void)new_func_graph->add_parameter(); | ||||
| @@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP | |||||
| new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); | new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); | ||||
| new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); | ||||
| new_func_graph->set_is_generate(func_graph->is_generated()); | new_func_graph->set_is_generate(func_graph->is_generated()); | ||||
| for (auto& item : func_graph->parameter_default_value()) { | |||||
| for (auto &item : func_graph->parameter_default_value()) { | |||||
| new_func_graph->set_param_default_value(item.first, cloner[item.second]); | new_func_graph->set_param_default_value(item.first, cloner[item.second]); | ||||
| } | } | ||||
| @@ -43,26 +43,26 @@ struct CloneInfo { | |||||
| class Cloner { | class Cloner { | ||||
| public: | public: | ||||
| explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, | |||||
| explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false, | |||||
| bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, | bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, | ||||
| const TraceInfoPtr& relation = std::make_shared<TraceCopy>(), | |||||
| const TraceInfoPtr& target_relation = nullptr); | |||||
| const TraceInfoPtr &relation = std::make_shared<TraceCopy>(), | |||||
| const TraceInfoPtr &target_relation = nullptr); | |||||
| ~Cloner() = default; | ~Cloner() = default; | ||||
| void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, | |||||
| const AnfNodePtrList& params = {}, CloneType type = kBasic); | |||||
| void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, | |||||
| const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); | |||||
| void Run(); | void Run(); | ||||
| // Interfaces for specializer | // Interfaces for specializer | ||||
| AnfNodePtr CloneDisconnected(const AnfNodePtr& root); | |||||
| AnfNodePtr operator[](const AnfNodePtr& node); | |||||
| FuncGraphPtr operator[](const FuncGraphPtr& func_graph); | |||||
| AnfNodePtr CloneDisconnected(const AnfNodePtr &root); | |||||
| AnfNodePtr operator[](const AnfNodePtr &node); | |||||
| FuncGraphPtr operator[](const FuncGraphPtr &func_graph); | |||||
| // Map of replicate nodes and graphs | // Map of replicate nodes and graphs | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; } | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; } | |||||
| std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; } | std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; } | ||||
| // Scope of cloned graphs | // Scope of cloned graphs | ||||
| void set_scope(const ScopePtr& scope) { scope_ = scope; } | |||||
| void set_scope(const ScopePtr &scope) { scope_ = scope; } | |||||
| const ScopePtr scope() const { return scope_; } | const ScopePtr scope() const { return scope_; } | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; | std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; | ||||
| @@ -71,31 +71,31 @@ class Cloner { | |||||
| void CloneNodes(); | void CloneNodes(); | ||||
| void LinkEdges(); | void LinkEdges(); | ||||
| void SetDefaults(); | void SetDefaults(); | ||||
| void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||||
| void CloneValueNode(const AnfNodePtr& node); | |||||
| void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||||
| void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); | |||||
| void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); | |||||
| void CloneValueNodes(const FuncGraphPtr& func_graph); | |||||
| void AddChildGraphs(const FuncGraphPtr& func_graph); | |||||
| void AddTotalGraphs(const FuncGraphPtr& func_graph); | |||||
| bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); | |||||
| void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||||
| void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||||
| void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||||
| void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); | |||||
| void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); | |||||
| void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); | |||||
| void GenParameters(const FuncGraphPtr& func_graph); | |||||
| void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); | |||||
| ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); | |||||
| void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, | |||||
| AnfNodePtrList* const input_params); | |||||
| void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); | |||||
| void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); | |||||
| void SetEdges(const FuncGraphPtr& func_graph); | |||||
| void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtrList& params); | |||||
| void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||||
| void CloneValueNode(const AnfNodePtr &node); | |||||
| void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||||
| void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); | |||||
| void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); | |||||
| void CloneValueNodes(const FuncGraphPtr &func_graph); | |||||
| void AddChildGraphs(const FuncGraphPtr &func_graph); | |||||
| void AddTotalGraphs(const FuncGraphPtr &func_graph); | |||||
| bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); | |||||
| void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||||
| void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||||
| void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||||
| void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); | |||||
| void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); | |||||
| void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); | |||||
| void GenParameters(const FuncGraphPtr &func_graph); | |||||
| void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node); | |||||
| ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); | |||||
| void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params, | |||||
| AnfNodePtrList *const input_params); | |||||
| void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); | |||||
| void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs); | |||||
| void SetEdges(const FuncGraphPtr &func_graph); | |||||
| void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtrList ¶ms); | |||||
| void Lift(); | void Lift(); | ||||
| void LiftParameters(); | void LiftParameters(); | ||||
| @@ -118,17 +118,17 @@ class Cloner { | |||||
| std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; | std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; | ||||
| }; | }; | ||||
| FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); | |||||
| FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); | |||||
| AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, | |||||
| const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); | |||||
| AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, | |||||
| const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); | |||||
| FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); | |||||
| FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); | |||||
| ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); | |||||
| ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); | |||||
| FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, | |||||
| const TraceInfoPtr& relation = std::make_shared<TraceTransform>()); | |||||
| FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, | |||||
| const TraceInfoPtr &relation = std::make_shared<TraceTransform>()); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ | #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ | ||||
| @@ -27,17 +27,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { | |||||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) { | |||||
| auto m = std::make_shared<FuncGraphManager>(func_graphs, manage); | auto m = std::make_shared<FuncGraphManager>(func_graphs, manage); | ||||
| m->Init(); | m->Init(); | ||||
| return m; | return m; | ||||
| } | } | ||||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage) { | |||||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) { | |||||
| FuncGraphManagerPtr m = nullptr; | FuncGraphManagerPtr m = nullptr; | ||||
| bool root = false; | bool root = false; | ||||
| for (auto& fg : func_graphs) { | |||||
| for (auto &fg : func_graphs) { | |||||
| if (fg == nullptr) { | if (fg == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool ma | |||||
| root = true; | root = true; | ||||
| } | } | ||||
| for (auto& fg : func_graphs) { | |||||
| for (auto &fg : func_graphs) { | |||||
| if (fg == nullptr) { | if (fg == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { | |||||
| return Manage(func_graphs, manage); | return Manage(func_graphs, manage); | ||||
| } | } | ||||
| FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage) | |||||
| FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage) | |||||
| : roots_(roots), is_manage_(manage) { | : roots_(roots), is_manage_(manage) { | ||||
| Reset(); | Reset(); | ||||
| } | } | ||||
| @@ -103,12 +103,12 @@ void FuncGraphManager::Init() { | |||||
| auto roots = roots_; | auto roots = roots_; | ||||
| roots_ = FuncGraphSet(); | roots_ = FuncGraphSet(); | ||||
| for (auto& fg : roots) { | |||||
| for (auto &fg : roots) { | |||||
| AddFuncGraph(fg, true); | AddFuncGraph(fg, true); | ||||
| } | } | ||||
| } | } | ||||
| FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { | |||||
| FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); | MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); | ||||
| func_graph_parents_total_->Recompute(fg); | func_graph_parents_total_->Recompute(fg); | ||||
| @@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) | |||||
| return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; | return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; | ||||
| } | } | ||||
| FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { | |||||
| FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| MS_EXCEPTION_IF_NULL(func_graph_parent_); | MS_EXCEPTION_IF_NULL(func_graph_parent_); | ||||
| MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); | MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); | ||||
| @@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { | |||||
| return func_graph_parent_->parent_analysis()[fg]; | return func_graph_parent_->parent_analysis()[fg]; | ||||
| } | } | ||||
| FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { | |||||
| FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| MS_EXCEPTION_IF_NULL(children_); | MS_EXCEPTION_IF_NULL(children_); | ||||
| MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); | MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); | ||||
| @@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { | |||||
| return children_->children_analysis()[fg]; | return children_->children_analysis()[fg]; | ||||
| } | } | ||||
| FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { | |||||
| FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| MS_EXCEPTION_IF_NULL(scopes_); | MS_EXCEPTION_IF_NULL(scopes_); | ||||
| MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); | MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); | ||||
| @@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { | |||||
| return scopes_->scope_analysis()[fg]; | return scopes_->scope_analysis()[fg]; | ||||
| } | } | ||||
| FVTotalMap& FuncGraphManager::free_variables_total() const { | |||||
| FVTotalMap &FuncGraphManager::free_variables_total() const { | |||||
| MS_EXCEPTION_IF_NULL(free_variables_total_); | MS_EXCEPTION_IF_NULL(free_variables_total_); | ||||
| free_variables_total_->Recompute(); | free_variables_total_->Recompute(); | ||||
| return free_variables_total_->fv_total_analysis(); | return free_variables_total_->fv_total_analysis(); | ||||
| } | } | ||||
| FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { | |||||
| FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graphs_used_total_); | MS_EXCEPTION_IF_NULL(func_graphs_used_total_); | ||||
| func_graphs_used_total_->Recompute(fg); | func_graphs_used_total_->Recompute(fg); | ||||
| return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; | return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; | ||||
| } | } | ||||
| bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { | |||||
| bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| recursive_->Recompute(fg); | recursive_->Recompute(fg); | ||||
| if (recursive_->recursive_analysis().count(fg) == 0) { | if (recursive_->recursive_analysis().count(fg) == 0) { | ||||
| @@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { | |||||
| return recursive_->recursive_analysis()[fg]; | return recursive_->recursive_analysis()[fg]; | ||||
| } | } | ||||
| std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { | |||||
| std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| if (recursive(fg)) { | if (recursive(fg)) { | ||||
| if (!recursive_->recursive_map().count(fg)) { | if (!recursive_->recursive_map().count(fg)) { | ||||
| @@ -185,7 +185,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons | |||||
| } | } | ||||
| } | } | ||||
| bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { | |||||
| bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { | |||||
| MS_EXCEPTION_IF_NULL(j_total_); | MS_EXCEPTION_IF_NULL(j_total_); | ||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| j_total_->Recompute(fg); | j_total_->Recompute(fg); | ||||
| @@ -225,10 +225,10 @@ void FuncGraphManager::Clear() { | |||||
| signals_->InvalidateComputer(); | signals_->InvalidateComputer(); | ||||
| } | } | ||||
| void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||||
| void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) { | |||||
| MS_LOG(DEBUG) << "Start keep roots"; | MS_LOG(DEBUG) << "Start keep roots"; | ||||
| bool root_exist = false; | bool root_exist = false; | ||||
| for (auto& item : func_graphs) { | |||||
| for (auto &item : func_graphs) { | |||||
| if (roots_.contains(item)) { | if (roots_.contains(item)) { | ||||
| root_exist = true; | root_exist = true; | ||||
| break; | break; | ||||
| @@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||||
| roots = roots_; | roots = roots_; | ||||
| } else { | } else { | ||||
| roots_.clear(); | roots_.clear(); | ||||
| for (auto& item : roots) { | |||||
| for (auto &item : roots) { | |||||
| AddFuncGraph(item, true); | AddFuncGraph(item, true); | ||||
| } | } | ||||
| } | } | ||||
| FuncGraphSet keep; | FuncGraphSet keep; | ||||
| for (auto& item : roots) { | |||||
| for (auto &item : roots) { | |||||
| MS_LOG(DEBUG) << "roots: " << item->ToString(); | MS_LOG(DEBUG) << "roots: " << item->ToString(); | ||||
| keep.update(func_graphs_used_total(item)); | keep.update(func_graphs_used_total(item)); | ||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| for (auto& k : keep) { | |||||
| for (auto &k : keep) { | |||||
| MS_LOG(DEBUG) << "keep: " << k->ToString(); | MS_LOG(DEBUG) << "keep: " << k->ToString(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr>& func_graphs) { | |||||
| } else { | } else { | ||||
| Clear(); | Clear(); | ||||
| FuncGraphSet roots(func_graphs); | FuncGraphSet roots(func_graphs); | ||||
| for (auto& item : roots) { | |||||
| for (auto &item : roots) { | |||||
| AddFuncGraph(item, true); | AddFuncGraph(item, true); | ||||
| } | } | ||||
| } | } | ||||
| @@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() { | |||||
| MaybeDropFuncGraphs(func_graphs_, true); | MaybeDropFuncGraphs(func_graphs_, true); | ||||
| } | } | ||||
| void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { | |||||
| void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| if (is_manage_) { | if (is_manage_) { | ||||
| if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { | if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { | ||||
| @@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { | |||||
| func_graphs_.add(fg); | func_graphs_.add(fg); | ||||
| } | } | ||||
| void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { | |||||
| void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { | |||||
| FuncGraphSet todo(func_graphs); | FuncGraphSet todo(func_graphs); | ||||
| std::set<FuncGraphPtr> dropped; | std::set<FuncGraphPtr> dropped; | ||||
| // int count = 0; | // int count = 0; | ||||
| @@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func_graph_users_); | MS_EXCEPTION_IF_NULL(func_graph_users_); | ||||
| auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||||
| auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; | |||||
| if (!users.empty() && !ignore_users) { | if (!users.empty() && !ignore_users) { | ||||
| MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); | ||||
| continue; | continue; | ||||
| @@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool | |||||
| todo.update(MaybeDropNodes(return_vec)); | todo.update(MaybeDropNodes(return_vec)); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(signals_); | MS_EXCEPTION_IF_NULL(signals_); | ||||
| for (auto& fg : dropped) { | |||||
| for (auto &fg : dropped) { | |||||
| MS_EXCEPTION_IF_NULL(fg); | MS_EXCEPTION_IF_NULL(fg); | ||||
| signals_->DropFuncGraph(fg); | signals_->DropFuncGraph(fg); | ||||
| all_nodes_.difference_update(fg->parameters()); | all_nodes_.difference_update(fg->parameters()); | ||||
| @@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||||
| MS_EXCEPTION_IF_NULL(inp); | MS_EXCEPTION_IF_NULL(inp); | ||||
| if (direction == kDecEdge) { | if (direction == kDecEdge) { | ||||
| MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); | ||||
| auto& users_node = node_users_[inp]; | |||||
| auto &users_node = node_users_[inp]; | |||||
| if (!users_node.contains(make_pair(node, index))) { | if (!users_node.contains(make_pair(node, index))) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E | |||||
| MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); | MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); | ||||
| AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); | AddFuncGraph(GetValueNode<FuncGraphPtr>(inp)); | ||||
| } | } | ||||
| auto& users_node = node_users_[inp]; | |||||
| auto &users_node = node_users_[inp]; | |||||
| users_node.add(make_pair(node, index)); | users_node.add(make_pair(node, index)); | ||||
| MS_EXCEPTION_IF_NULL(signals_); | MS_EXCEPTION_IF_NULL(signals_); | ||||
| signals_->AddEdge(node, index, inp); | signals_->AddEdge(node, index, inp); | ||||
| } | } | ||||
| } | } | ||||
| void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { | |||||
| void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| int index = 0; | int index = 0; | ||||
| for (auto& inp : cnode->inputs()) { | |||||
| for (auto &inp : cnode->inputs()) { | |||||
| ProcessEdge(cnode, index, inp, direction); | ProcessEdge(cnode, index, inp, direction); | ||||
| ++index; | ++index; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { | |||||
| IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { | |||||
| if (all_nodes_.contains(node)) { | if (all_nodes_.contains(node)) { | ||||
| return EXCLUDE; | return EXCLUDE; | ||||
| } else { | } else { | ||||
| @@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||||
| void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) { | |||||
| AnfNodeSet acq; | AnfNodeSet acq; | ||||
| for (auto& node : nodes) { | |||||
| for (auto &node : nodes) { | |||||
| std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); | std::function<IncludeType(AnfNodePtr)> limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); | ||||
| AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); | AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); | ||||
| @@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||||
| acq.update(new_nodes); | acq.update(new_nodes); | ||||
| } | } | ||||
| for (auto& node : acq) { | |||||
| for (auto &node : acq) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| FuncGraphPtr fg = node->func_graph(); | FuncGraphPtr fg = node->func_graph(); | ||||
| if (fg != nullptr) { | if (fg != nullptr) { | ||||
| @@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr>& nodes) { | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& nodes) { | |||||
| FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { | |||||
| AnfNodeSet nodes_ordered(nodes); | AnfNodeSet nodes_ordered(nodes); | ||||
| FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | ||||
| MS_EXCEPTION_IF_NULL(signals_); | MS_EXCEPTION_IF_NULL(signals_); | ||||
| @@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& | |||||
| if (!all_nodes_.contains(node)) { | if (!all_nodes_.contains(node)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| AnfNodeIndexSet& users = node_users_[node]; | |||||
| AnfNodeIndexSet &users = node_users_[node]; | |||||
| std::vector<AnfNodePtr> parameters; | std::vector<AnfNodePtr> parameters; | ||||
| if (!users.empty() || | if (!users.empty() || | ||||
| @@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr>& | |||||
| return func_graphs_to_check; | return func_graphs_to_check; | ||||
| } | } | ||||
| void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters) { | |||||
| void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters) { | |||||
| auto tr = Transact(); | auto tr = Transact(); | ||||
| tr.SetParameters(fg, parameters); | tr.SetParameters(fg, parameters); | ||||
| tr.Commit(); | tr.Commit(); | ||||
| } | } | ||||
| bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { | |||||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||||
| auto tr = Transact(); | auto tr = Transact(); | ||||
| bool success = tr.Replace(old_node, new_node); | bool success = tr.Replace(old_node, new_node); | ||||
| if (success) { | if (success) { | ||||
| @@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new | |||||
| return success; | return success; | ||||
| } | } | ||||
| void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { | |||||
| void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { | |||||
| auto tr = Transact(); | auto tr = Transact(); | ||||
| tr.SetEdge(node, index, value); | tr.SetEdge(node, index, value); | ||||
| tr.Commit(); | tr.Commit(); | ||||
| } | } | ||||
| void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { | |||||
| void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { | |||||
| AnfNodePtr source_return = source->get_return(); | AnfNodePtr source_return = source->get_return(); | ||||
| AnfNodePtr source_output = source->output(); | AnfNodePtr source_output = source->output(); | ||||
| AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0); | AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0); | ||||
| @@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||||
| (void)all_nodes_.erase(source_return); | (void)all_nodes_.erase(source_return); | ||||
| (void)node_users_.erase(source_return); | (void)node_users_.erase(source_return); | ||||
| signals_->DropNode(source_return); | signals_->DropNode(source_return); | ||||
| for (auto& node : source->nodes()) { | |||||
| for (auto &node : source->nodes()) { | |||||
| node->set_func_graph(target); | node->set_func_graph(target); | ||||
| if (node->scope() == kDefaultScope) { | if (node->scope() == kDefaultScope) { | ||||
| node->set_scope(scope); | node->set_scope(scope); | ||||
| } | } | ||||
| } | } | ||||
| for (auto& used : source->func_graphs_used()) { | |||||
| for (auto &used : source->func_graphs_used()) { | |||||
| (void)func_graph_users_->Inc(used.first, target, used.second); | (void)func_graph_users_->Inc(used.first, target, used.second); | ||||
| (void)this->func_graph_users()[used.first].erase(source); | (void)this->func_graph_users()[used.first].erase(source); | ||||
| } | } | ||||
| for (auto& child : this->func_graph_child_direct()[source]) { | |||||
| for (auto &child : this->func_graph_child_direct()[source]) { | |||||
| (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | (void)func_graph_parents_direct_->Inc(child.first, target, child.second); | ||||
| (void)this->func_graph_parents_direct()[child.first].erase(source); | (void)this->func_graph_parents_direct()[child.first].erase(source); | ||||
| } | } | ||||
| for (auto& fv_count : this->free_variables_direct()[source]) { | |||||
| for (auto &fv_count : this->free_variables_direct()[source]) { | |||||
| auto fv_g = fv_count.first->func_graph(); | auto fv_g = fv_count.first->func_graph(); | ||||
| auto& count_on_g = this->func_graph_child_direct()[fv_g]; | |||||
| auto &count_on_g = this->func_graph_child_direct()[fv_g]; | |||||
| auto pair = count_on_g.find(source); | auto pair = count_on_g.find(source); | ||||
| if (fv_g != target && pair != count_on_g.end()) { | if (fv_g != target && pair != count_on_g.end()) { | ||||
| (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); | (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); | ||||
| @@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() { | |||||
| return tr; | return tr; | ||||
| } | } | ||||
| void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, | |||||
| EdgeTupleCounter* rm_edges, Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms) { | |||||
| for (auto& iter : changes) { | |||||
| void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, | |||||
| EdgeTupleCounter *rm_edges, Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms) { | |||||
| for (auto &iter : changes) { | |||||
| auto operation = iter.op; | auto operation = iter.op; | ||||
| auto args = iter.args; | auto args = iter.args; | ||||
| if (operation == Change::kTxSetEdge) { | if (operation == Change::kTxSetEdge) { | ||||
| @@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl | |||||
| auto param = args.cast<ArgsOfSetParams>(); | auto param = args.cast<ArgsOfSetParams>(); | ||||
| MS_EXCEPTION_IF_NULL(param.func_graph); | MS_EXCEPTION_IF_NULL(param.func_graph); | ||||
| auto old_parameters = param.func_graph->parameters(); | auto old_parameters = param.func_graph->parameters(); | ||||
| for (auto& p : param.params) { | |||||
| for (auto &p : param.params) { | |||||
| (*adds)[p] += 1; | (*adds)[p] += 1; | ||||
| } | } | ||||
| for (auto& p : old_parameters) { | |||||
| for (auto &p : old_parameters) { | |||||
| (*rms)[p] += 1; | (*rms)[p] += 1; | ||||
| } | } | ||||
| param.func_graph->set_parameters(param.params); | param.func_graph->set_parameters(param.params); | ||||
| @@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector<Change>& changes, EdgeTupl | |||||
| } | } | ||||
| } | } | ||||
| void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||||
| void FuncGraphManager::CommitChanges(const std::vector<Change> &changes) { | |||||
| EdgeTupleCounter add_edges; | EdgeTupleCounter add_edges; | ||||
| EdgeTupleCounter rm_edges; | EdgeTupleCounter rm_edges; | ||||
| Counter<AnfNodePtr> adds; | Counter<AnfNodePtr> adds; | ||||
| @@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||||
| ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); | ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); | ||||
| auto sub_edges = add_edges - rm_edges; | auto sub_edges = add_edges - rm_edges; | ||||
| for (auto& iter : sub_edges) { | |||||
| for (auto &iter : sub_edges) { | |||||
| auto root_node = iter.first.first; | auto root_node = iter.first.first; | ||||
| int index = iter.first.second.first; | int index = iter.first.second.first; | ||||
| auto new_node = iter.first.second.second; | auto new_node = iter.first.second.second; | ||||
| @@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||||
| auto sub_nodes = adds - rms; | auto sub_nodes = adds - rms; | ||||
| std::vector<AnfNodePtr> nodes; | std::vector<AnfNodePtr> nodes; | ||||
| (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), | (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), | ||||
| [](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); | |||||
| [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; }); | |||||
| AcquireNodes(nodes); | AcquireNodes(nodes); | ||||
| auto sub_edges_reverse = rm_edges - add_edges; | auto sub_edges_reverse = rm_edges - add_edges; | ||||
| for (auto& iter : sub_edges_reverse) { | |||||
| for (auto &iter : sub_edges_reverse) { | |||||
| auto root_node = iter.first.first; | auto root_node = iter.first.first; | ||||
| int index = iter.first.second.first; | int index = iter.first.second.first; | ||||
| auto old_node = iter.first.second.second; | auto old_node = iter.first.second.second; | ||||
| @@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector<Change>& changes) { | |||||
| std::vector<AnfNodePtr> nodes_reverse; | std::vector<AnfNodePtr> nodes_reverse; | ||||
| (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), | (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), | ||||
| [](const std::pair<const AnfNodePtr, int>& iter) -> AnfNodePtr { return iter.first; }); | |||||
| [](const std::pair<const AnfNodePtr, int> &iter) -> AnfNodePtr { return iter.first; }); | |||||
| auto drop_func_graphs = MaybeDropNodes(nodes_reverse); | auto drop_func_graphs = MaybeDropNodes(nodes_reverse); | ||||
| MaybeDropFuncGraphs(*drop_func_graphs); | MaybeDropFuncGraphs(*drop_func_graphs); | ||||
| } | } | ||||
| void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params) { | |||||
| void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) { | |||||
| changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); | ||||
| } | } | ||||
| bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { | |||||
| bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||||
| MS_EXCEPTION_IF_NULL(old_node); | MS_EXCEPTION_IF_NULL(old_node); | ||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| FuncGraphPtr old_func_graph = old_node->func_graph(); | FuncGraphPtr old_func_graph = old_node->func_graph(); | ||||
| @@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto users = manager_->node_users()[old_node]; | auto users = manager_->node_users()[old_node]; | ||||
| for (auto& node : users) { | |||||
| for (auto &node : users) { | |||||
| SetEdge(node.first, node.second, new_node); | SetEdge(node.first, node.second, new_node); | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { | |||||
| void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { | |||||
| if (k < 0) { | if (k < 0) { | ||||
| MS_LOG(EXCEPTION) << "Invalid value k = " << k; | MS_LOG(EXCEPTION) << "Invalid value k = " << k; | ||||
| } | } | ||||
| @@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() { | |||||
| manager_->CommitChanges(changes); | manager_->CommitChanges(changes); | ||||
| } | } | ||||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) | |||||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) | |||||
| : manager_(manager), include_func_graph_none_(false) { | : manager_(manager), include_func_graph_none_(false) { | ||||
| manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); | manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); | ||||
| manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); | manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); | ||||
| @@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) | |||||
| manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); | manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); | ||||
| } | } | ||||
| NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { | |||||
| NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { | |||||
| include_func_graph_none_ = true; | include_func_graph_none_ = true; | ||||
| nodes_analysis_[nullptr] = AnfNodeSet(); | nodes_analysis_[nullptr] = AnfNodeSet(); | ||||
| @@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) { | |||||
| void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| // change the owner of node except for the src's return node | // change the owner of node except for the src's return node | ||||
| for (auto& it : nodes_analysis_[src]) { | |||||
| for (auto &it : nodes_analysis_[src]) { | |||||
| nodes_analysis_[dst].add(it); | nodes_analysis_[dst].add(it); | ||||
| } | } | ||||
| (void)nodes_analysis_.erase(src); | (void)nodes_analysis_.erase(src); | ||||
| @@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | ||||
| DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { | |||||
| DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); | manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); | ||||
| } | } | ||||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | ||||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { | |||||
| auto& d = count_nodes_map_[func_graph]; | |||||
| bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||||
| auto &d = count_nodes_map_[func_graph]; | |||||
| if (d.count(key) == 0) { | if (d.count(key) == 0) { | ||||
| d[key] = count; | d[key] = count; | ||||
| return true; | return true; | ||||
| @@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { | |||||
| bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto& d = count_nodes_map_[func_graph]; | |||||
| auto &d = count_nodes_map_[func_graph]; | |||||
| if (d.count(key) != 0) { | if (d.count(key) != 0) { | ||||
| if (d[key] == count) { | if (d[key] == count) { | ||||
| (void)d.erase(key); | (void)d.erase(key); | ||||
| @@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { | |||||
| bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { | |||||
| if (count > 0) { | if (count > 0) { | ||||
| return Inc(func_graph, key, count); | return Inc(func_graph, key, count); | ||||
| } else if (count < 0) { | } else if (count < 0) { | ||||
| @@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP | |||||
| } | } | ||||
| } | } | ||||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { | |||||
| auto& d = count_func_graphs_map_[func_graph]; | |||||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) == 0) { | if (d.count(key) == 0) { | ||||
| d[key] = count; | d[key] = count; | ||||
| return true; | return true; | ||||
| @@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { | |||||
| auto& d = count_func_graphs_map_[func_graph]; | |||||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||||
| auto &d = count_func_graphs_map_[func_graph]; | |||||
| if (d.count(key) != 0) { | if (d.count(key) != 0) { | ||||
| if (d[key] == count) { | if (d[key] == count) { | ||||
| (void)d.erase(key); | (void)d.erase(key); | ||||
| @@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { | |||||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||||
| if (count > 0) { | if (count > 0) { | ||||
| return Inc(func_graph, key, count); | return Inc(func_graph, key, count); | ||||
| } else if (count < 0) { | } else if (count < 0) { | ||||
| @@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr | |||||
| } | } | ||||
| void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_nodes_map_[src]) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| @@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed | |||||
| } | } | ||||
| void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_nodes_map_[src]) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| @@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc | |||||
| } | } | ||||
| void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_nodes_map_[src]) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| FuncGraphPtr fg2 = it.first->func_graph(); | FuncGraphPtr fg2 = it.first->func_graph(); | ||||
| if (fg2 != dst) { | if (fg2 != dst) { | ||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| @@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | |||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| } | } | ||||
| static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { | |||||
| static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { | |||||
| FuncGraphPtr gn = std::make_shared<FuncGraph>(); | FuncGraphPtr gn = std::make_shared<FuncGraph>(); | ||||
| (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); | (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); | ||||
| return gn; | return gn; | ||||
| @@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP | |||||
| } | } | ||||
| void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_func_graphs_map_[src]) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| FuncGraphPtr fg = it.first; | FuncGraphPtr fg = it.first; | ||||
| if (fg != dst) { | if (fg != dst) { | ||||
| (void)Inc(dst, fg, it.second); | (void)Inc(dst, fg, it.second); | ||||
| @@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr | |||||
| } | } | ||||
| void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_func_graphs_map_[src]) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| if (it.first != dst) { | if (it.first != dst) { | ||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| @@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed | |||||
| void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| // all graph use in src need to change to dst, so meger the to dst use | // all graph use in src need to change to dst, so meger the to dst use | ||||
| for (auto& it : count_func_graphs_map_[src]) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| (void)count_func_graphs_map_[dst].erase(src); | (void)count_func_graphs_map_[dst].erase(src); | ||||
| @@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp | |||||
| } | } | ||||
| void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| for (auto& it : count_nodes_map_[src]) { | |||||
| for (auto &it : count_nodes_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| (void)count_nodes_map_.erase(src); | (void)count_nodes_map_.erase(src); | ||||
| @@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, | |||||
| void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { | ||||
| // all graph use in src need to change to dst, so meger the to dst use | // all graph use in src need to change to dst, so meger the to dst use | ||||
| for (auto& it : count_func_graphs_map_[src]) { | |||||
| for (auto &it : count_func_graphs_map_[src]) { | |||||
| (void)Inc(dst, it.first, it.second); | (void)Inc(dst, it.first, it.second); | ||||
| } | } | ||||
| (void)count_func_graphs_map_.erase(src); | (void)count_func_graphs_map_.erase(src); | ||||
| } | } | ||||
| DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { | |||||
| DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | ||||
| validate_ = false; | validate_ = false; | ||||
| @@ -914,20 +914,20 @@ void DepComputer::Recompute() { | |||||
| } | } | ||||
| } | } | ||||
| void DepComputer::Recompute(const FuncGraphPtr& fg) { | |||||
| void DepComputer::Recompute(const FuncGraphPtr &fg) { | |||||
| if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { | if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { | ||||
| RealRecompute(fg); | RealRecompute(fg); | ||||
| func_graphs_validate_[fg] = true; | func_graphs_validate_[fg] = true; | ||||
| } | } | ||||
| } | } | ||||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { | |||||
| FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||||
| if (path == nullptr || path->contains(fg)) { | if (path == nullptr || path->contains(fg)) { | ||||
| return std::make_shared<FuncGraphSet>(); | return std::make_shared<FuncGraphSet>(); | ||||
| } | } | ||||
| FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); | ||||
| FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; | |||||
| for (auto& dep : deps[fg]) { | |||||
| FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; | |||||
| for (auto &dep : deps[fg]) { | |||||
| MS_EXCEPTION_IF_NULL(dep.first); | MS_EXCEPTION_IF_NULL(dep.first); | ||||
| auto proxy = dep.first->transforms().find("proxy"); | auto proxy = dep.first->transforms().find("proxy"); | ||||
| if (proxy != dep.first->transforms().end()) { | if (proxy != dep.first->transforms().end()) { | ||||
| @@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); | MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); | ||||
| } | } | ||||
| bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { | |||||
| bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { | |||||
| auto l1 = lhs.second.size(); | auto l1 = lhs.second.size(); | ||||
| auto l2 = rhs.second.size(); | auto l2 = rhs.second.size(); | ||||
| return l1 < l2; | return l1 < l2; | ||||
| @@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| } else { | } else { | ||||
| // return nearest parent as parent | // return nearest parent as parent | ||||
| FuncGraphSet deps_copy(deps); | FuncGraphSet deps_copy(deps); | ||||
| for (auto& dep : deps) { | |||||
| for (auto &dep : deps) { | |||||
| auto parent_deps = this->manager_->func_graph_parents_total(dep); | auto parent_deps = this->manager_->func_graph_parents_total(dep); | ||||
| for (auto& p_d : parent_deps) { | |||||
| for (auto &p_d : parent_deps) { | |||||
| if (deps_copy.count(p_d)) { | if (deps_copy.count(p_d)) { | ||||
| (void)deps_copy.erase(p_d); | (void)deps_copy.erase(p_d); | ||||
| } | } | ||||
| @@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { | void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| auto used_fg_total = manager_->func_graphs_used_total(fg); | auto used_fg_total = manager_->func_graphs_used_total(fg); | ||||
| for (auto& used_fg : used_fg_total) { | |||||
| for (auto &used_fg : used_fg_total) { | |||||
| if (manager_->parent(used_fg) == fg) { | if (manager_->parent(used_fg) == fg) { | ||||
| children_analysis_[fg].add(used_fg); | children_analysis_[fg].add(used_fg); | ||||
| } | } | ||||
| @@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| void ScopeComputer::RealRecompute(FuncGraphPtr fg) { | void ScopeComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| auto& children = manager_->children(fg); | |||||
| auto &children = manager_->children(fg); | |||||
| scope_analysis_[fg] = FuncGraphSet(); | scope_analysis_[fg] = FuncGraphSet(); | ||||
| scope_analysis_[fg].add(fg); | scope_analysis_[fg].add(fg); | ||||
| for (auto& child : children) { | |||||
| for (auto &child : children) { | |||||
| scope_analysis_[fg].add(child); | scope_analysis_[fg].add(child); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() { | |||||
| auto manager = DepComputer::manager_; | auto manager = DepComputer::manager_; | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| for (auto& fg : manager->func_graphs()) { | |||||
| for (auto &fg : manager->func_graphs()) { | |||||
| fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); | fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); | ||||
| count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); | count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); | ||||
| count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); | count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); | ||||
| } | } | ||||
| for (auto& fg : manager->func_graphs()) { | |||||
| for (auto &fg : manager->func_graphs()) { | |||||
| AnfNodeCounterMap items = manager->free_variables_direct()[fg]; | AnfNodeCounterMap items = manager->free_variables_direct()[fg]; | ||||
| for (auto& iter : items) { | |||||
| for (auto &iter : items) { | |||||
| auto curr = fg; | auto curr = fg; | ||||
| while (curr) { | while (curr) { | ||||
| (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | ||||
| curr = manager->parent(curr); | curr = manager->parent(curr); | ||||
| const AnfNodeSet& nodes = manager->nodes()[curr]; | |||||
| const AnfNodeSet &nodes = manager->nodes()[curr]; | |||||
| if (nodes.contains(iter.first)) { | if (nodes.contains(iter.first)) { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() { | |||||
| } | } | ||||
| auto items_fg = manager->func_graphs_used()[fg]; | auto items_fg = manager->func_graphs_used()[fg]; | ||||
| for (auto& iter : items_fg) { | |||||
| for (auto &iter : items_fg) { | |||||
| auto p = manager->parent(iter.first); | auto p = manager->parent(iter.first); | ||||
| if (p == nullptr) { | if (p == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| for (auto& fg : manager->func_graphs()) { | |||||
| auto& fvp = count_nodes_map_[fg]; | |||||
| auto& fvg = count_func_graphs_map_[fg]; | |||||
| for (auto& item : fvp) { | |||||
| for (auto &fg : manager->func_graphs()) { | |||||
| auto &fvp = count_nodes_map_[fg]; | |||||
| auto &fvg = count_func_graphs_map_[fg]; | |||||
| for (auto &item : fvp) { | |||||
| fv_total_analysis_[fg][item.first] = item.second; | fv_total_analysis_[fg][item.first] = item.second; | ||||
| } | } | ||||
| for (auto& item : fvg) { | |||||
| for (auto &item : fvg) { | |||||
| fv_total_analysis_[fg][item.first] = item.second; | fv_total_analysis_[fg][item.first] = item.second; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() { | |||||
| void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| auto& used = this->manager_->func_graphs_used(); | |||||
| auto &used = this->manager_->func_graphs_used(); | |||||
| std::vector<FuncGraphPtr> todo; | std::vector<FuncGraphPtr> todo; | ||||
| std::vector<FuncGraphPtr> todo_new; | std::vector<FuncGraphPtr> todo_new; | ||||
| todo.push_back(fg); | todo.push_back(fg); | ||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| todo_new.clear(); | todo_new.clear(); | ||||
| for (auto& gt : todo) { | |||||
| for (auto& item : used[gt]) { | |||||
| for (auto > : todo) { | |||||
| for (auto &item : used[gt]) { | |||||
| auto used_fg = item.first; | auto used_fg = item.first; | ||||
| if (used_fg == fg) { | if (used_fg == fg) { | ||||
| func_graph_used_total_analysis_[fg].add(used_fg); | func_graph_used_total_analysis_[fg].add(used_fg); | ||||
| @@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| } | } | ||||
| } | } | ||||
| bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { | |||||
| bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto& used = manager->func_graphs_used(); | |||||
| auto &used = manager->func_graphs_used(); | |||||
| std::vector<FuncGraphPtr> todo; | std::vector<FuncGraphPtr> todo; | ||||
| std::vector<FuncGraphPtr> todo_new; | std::vector<FuncGraphPtr> todo_new; | ||||
| todo.push_back(fg); | todo.push_back(fg); | ||||
| FuncGraphSet used_total; | FuncGraphSet used_total; | ||||
| while (!todo.empty()) { | while (!todo.empty()) { | ||||
| todo_new.clear(); | todo_new.clear(); | ||||
| for (auto& gt : todo) { | |||||
| for (auto& item : used[gt]) { | |||||
| for (auto > : todo) { | |||||
| for (auto &item : used[gt]) { | |||||
| auto used_g = item.first; | auto used_g = item.first; | ||||
| if (used_g == fg) { | if (used_g == fg) { | ||||
| return true; | return true; | ||||
| @@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { | |||||
| this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); | this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); | ||||
| } | } | ||||
| void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace) { | |||||
| void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) { | |||||
| MS_EXCEPTION_IF_NULL(trace); | MS_EXCEPTION_IF_NULL(trace); | ||||
| auto res = std::find(trace->begin(), trace->end(), fg); | auto res = std::find(trace->begin(), trace->end(), fg); | ||||
| // find recursive | // find recursive | ||||
| @@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F | |||||
| } | } | ||||
| } else { | } else { | ||||
| trace->push_back(fg); | trace->push_back(fg); | ||||
| auto& used_fgs = manager_->func_graphs_used()[fg]; | |||||
| auto &used_fgs = manager_->func_graphs_used()[fg]; | |||||
| for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { | for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { | ||||
| CheckRecursiveGraphs(iter->first, trace); | CheckRecursiveGraphs(iter->first, trace); | ||||
| } | } | ||||
| @@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<F | |||||
| } | } | ||||
| } | } | ||||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { | |||||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { | |||||
| MS_EXCEPTION_IF_NULL(path); | MS_EXCEPTION_IF_NULL(path); | ||||
| if (path->contains(fg)) { | if (path->contains(fg)) { | ||||
| MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; | MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| auto& func_graph_counter_map = manager_->func_graph_j_direct(); | |||||
| auto &func_graph_counter_map = manager_->func_graph_j_direct(); | |||||
| if (!func_graph_counter_map[fg].empty()) { | if (!func_graph_counter_map[fg].empty()) { | ||||
| // check g1->J(fg)->g2->g cycle; | // check g1->J(fg)->g2->g cycle; | ||||
| auto contains_j = | auto contains_j = | ||||
| @@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt | |||||
| path->add(fg); | path->add(fg); | ||||
| // check if func graphs used contains J(func_graph); | // check if func graphs used contains J(func_graph); | ||||
| auto& used = this->manager_->func_graphs_used(); | |||||
| for (auto& item : used[fg]) { | |||||
| auto &used = this->manager_->func_graphs_used(); | |||||
| for (auto &item : used[fg]) { | |||||
| auto used_g = item.first; | auto used_g = item.first; | ||||
| if (SeekJ(used_g, path)) { | if (SeekJ(used_g, path)) { | ||||
| MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() | MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() | ||||
| @@ -46,13 +46,13 @@ class FuncGraphManager; | |||||
| using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | ||||
| struct AnfNodeIndexPairHasher { | struct AnfNodeIndexPairHasher { | ||||
| std::size_t operator()(const std::pair<AnfNodePtr, int>& p1) const { | |||||
| return std::hash<const AnfNode*>{}(p1.first.get()); | |||||
| std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const { | |||||
| return std::hash<const AnfNode *>{}(p1.first.get()); | |||||
| } | } | ||||
| }; | }; | ||||
| struct AnfNodeIndexPairEqual { | struct AnfNodeIndexPairEqual { | ||||
| bool operator()(const std::pair<AnfNodePtr, int>& lhs, const std::pair<AnfNodePtr, int>& rhs) const { | |||||
| bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const { | |||||
| return lhs == rhs; | return lhs == rhs; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>; | |||||
| using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>; | using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>; | ||||
| using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>; | using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>; | ||||
| struct EdgeTupleHasher { | struct EdgeTupleHasher { | ||||
| std::size_t operator()(const EdgeTuple& p1) const { | |||||
| return hash_combine({std::hash<AnfNode*>{}(p1.first.get()), std::hash<int>{}(p1.second.first), | |||||
| std::hash<AnfNode*>{}(p1.second.second.get())}); | |||||
| std::size_t operator()(const EdgeTuple &p1) const { | |||||
| return hash_combine({std::hash<AnfNode *>{}(p1.first.get()), std::hash<int>{}(p1.second.first), | |||||
| std::hash<AnfNode *>{}(p1.second.second.get())}); | |||||
| } | } | ||||
| }; | }; | ||||
| struct EdgeTupleEqual { | struct EdgeTupleEqual { | ||||
| bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { | |||||
| bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const { | |||||
| return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; | return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter<EdgeTuple, EdgeTupleHasher, EdgeTupleEqual>; | |||||
| // FuncGraphManagerPtr: return created manager | // FuncGraphManagerPtr: return created manager | ||||
| FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); | FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); | ||||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr>& func_graphs, bool manage = true); | |||||
| FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true); | |||||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr>& func_graphs = {}, bool manage = true); | |||||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true); | |||||
| struct Signals { | struct Signals { | ||||
| Signal<void(FuncGraphPtr)> AddFuncGraph; | Signal<void(FuncGraphPtr)> AddFuncGraph; | ||||
| @@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNode | |||||
| // analysis base class | // analysis base class | ||||
| class FuncGraphAnalysis { | class FuncGraphAnalysis { | ||||
| public: | public: | ||||
| explicit FuncGraphAnalysis(const FuncGraphManager* const manager); | |||||
| explicit FuncGraphAnalysis(const FuncGraphManager *const manager); | |||||
| virtual ~FuncGraphAnalysis() { manager_ = nullptr; } | virtual ~FuncGraphAnalysis() { manager_ = nullptr; } | ||||
| @@ -130,7 +130,7 @@ class FuncGraphAnalysis { | |||||
| virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} | virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} | ||||
| const FuncGraphManager* manager_; | |||||
| const FuncGraphManager *manager_; | |||||
| bool include_func_graph_none_; | bool include_func_graph_none_; | ||||
| }; | }; | ||||
| @@ -139,7 +139,7 @@ using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; | |||||
| // graphs analysis which compute in write, read needn't recompute | // graphs analysis which compute in write, read needn't recompute | ||||
| class DepCollector : public FuncGraphAnalysis { | class DepCollector : public FuncGraphAnalysis { | ||||
| public: | public: | ||||
| explicit DepCollector(const FuncGraphManager* manager); | |||||
| explicit DepCollector(const FuncGraphManager *manager); | |||||
| ~DepCollector() override = default; | ~DepCollector() override = default; | ||||
| void Reset() { ExtraReset(); } | void Reset() { ExtraReset(); } | ||||
| @@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis { | |||||
| class NodesCollector final : public DepCollector { | class NodesCollector final : public DepCollector { | ||||
| public: | public: | ||||
| explicit NodesCollector(const FuncGraphManager* m); | |||||
| explicit NodesCollector(const FuncGraphManager *m); | |||||
| ~NodesCollector() override = default; | ~NodesCollector() override = default; | ||||
| const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } | |||||
| const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } | |||||
| size_t size() const override { return nodes_analysis_.size(); } | size_t size() const override { return nodes_analysis_.size(); } | ||||
| void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } | void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } | ||||
| @@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector { | |||||
| class CounterFuncGraphCollector : public DepCollector { | class CounterFuncGraphCollector : public DepCollector { | ||||
| public: | public: | ||||
| explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} | |||||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterFuncGraphCollector() override = default; | ~CounterFuncGraphCollector() override = default; | ||||
| FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||||
| // inherit from FuncGraphAnalysis | // inherit from FuncGraphAnalysis | ||||
| size_t size() const override { return count_func_graphs_map_.size(); } | size_t size() const override { return count_func_graphs_map_.size(); } | ||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | ||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | ||||
| bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||||
| bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||||
| bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); | |||||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | ||||
| @@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector { | |||||
| class CounterAnfNodeCollector : public DepCollector { | class CounterAnfNodeCollector : public DepCollector { | ||||
| public: | public: | ||||
| explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} | |||||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||||
| ~CounterAnfNodeCollector() override = default; | ~CounterAnfNodeCollector() override = default; | ||||
| FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } | |||||
| size_t size() const override { return count_nodes_map_.size(); } | size_t size() const override { return count_nodes_map_.size(); } | ||||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } | void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } | ||||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | ||||
| bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||||
| bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||||
| bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); | |||||
| bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); | |||||
| FuncGraphToAnfNodeCounterMap count_nodes_map_; | FuncGraphToAnfNodeCounterMap count_nodes_map_; | ||||
| @@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector { | |||||
| class ValueNodesCollector final : public CounterAnfNodeCollector { | class ValueNodesCollector final : public CounterAnfNodeCollector { | ||||
| public: | public: | ||||
| explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||||
| explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~ValueNodesCollector() override = default; | ~ValueNodesCollector() override = default; | ||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| @@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { | |||||
| class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||||
| explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FuncGraphValueNodesCollector() override = default; | ~FuncGraphValueNodesCollector() override = default; | ||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| @@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { | |||||
| class FVDirectCollector final : public CounterAnfNodeCollector { | class FVDirectCollector final : public CounterAnfNodeCollector { | ||||
| public: | public: | ||||
| explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||||
| explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| ~FVDirectCollector() override = default; | ~FVDirectCollector() override = default; | ||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| @@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector { | |||||
| class FuncGraphChildDirect final : public CounterFuncGraphCollector { | class FuncGraphChildDirect final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||||
| explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| ~FuncGraphChildDirect() override = default; | ~FuncGraphChildDirect() override = default; | ||||
| @@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector { | |||||
| // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f | // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f | ||||
| class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||||
| explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| ~FuncGraphParentsDirectCollector() override = default; | ~FuncGraphParentsDirectCollector() override = default; | ||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| @@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { | |||||
| // graph's all used graphs: key is g, value is g used graph | // graph's all used graphs: key is g, value is g used graph | ||||
| class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||||
| explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| ~FuncGraphsUsedCollector() override = default; | ~FuncGraphsUsedCollector() override = default; | ||||
| @@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { | |||||
| // graph's all user graphs: key is g, value is graphs who used g | // graph's all user graphs: key is g, value is graphs who used g | ||||
| class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||||
| explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| ~FuncGraphUsersCollector() override = default; | ~FuncGraphUsersCollector() override = default; | ||||
| @@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector { | |||||
| // graph's all user cnodes: key is g, value is cnodes who used g | // graph's all user cnodes: key is g, value is cnodes who used g | ||||
| class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} | |||||
| explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; | ||||
| ~FuncGraphUserNodesCollector() override = default; | ~FuncGraphUserNodesCollector() override = default; | ||||
| @@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { | |||||
| class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} | |||||
| explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} | |||||
| void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; | void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; | ||||
| ~FuncGraphJDirectCollector() override = default; | ~FuncGraphJDirectCollector() override = default; | ||||
| @@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; | |||||
| // graphs analysis which need dynamic compute by DepCollector in each read | // graphs analysis which need dynamic compute by DepCollector in each read | ||||
| class DepComputer : public FuncGraphAnalysis { | class DepComputer : public FuncGraphAnalysis { | ||||
| public: | public: | ||||
| explicit DepComputer(const FuncGraphManager* manager); | |||||
| explicit DepComputer(const FuncGraphManager *manager); | |||||
| ~DepComputer() override = default; | ~DepComputer() override = default; | ||||
| void Reset() { | void Reset() { | ||||
| @@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis { | |||||
| void Recompute(); | void Recompute(); | ||||
| void Recompute(const FuncGraphPtr& fg); | |||||
| void Recompute(const FuncGraphPtr &fg); | |||||
| bool IsValidate() const { return validate_; } | bool IsValidate() const { return validate_; } | ||||
| bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } | |||||
| bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } | |||||
| void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } | void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } | ||||
| @@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis { | |||||
| // graph g's all direct or proxy parents | // graph g's all direct or proxy parents | ||||
| class FuncGraphParentsTotalComputer final : public DepComputer { | class FuncGraphParentsTotalComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||||
| explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} | |||||
| ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } | ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } | ||||
| FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | |||||
| FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } | |||||
| size_t size() const override { return func_graph_parents_total_analysis_.size(); } | size_t size() const override { return func_graph_parents_total_analysis_.size(); } | ||||
| @@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer { | |||||
| void RealRecompute(FuncGraphPtr fg) override; | void RealRecompute(FuncGraphPtr fg) override; | ||||
| private: | private: | ||||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>()); | |||||
| FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>()); | |||||
| // when SeekParents calls itself recursively, it can access these variables by class member | // when SeekParents calls itself recursively, it can access these variables by class member | ||||
| // other than pass by formal parameters, it can save 1 parameter for SeekParents(). | // other than pass by formal parameters, it can save 1 parameter for SeekParents(). | ||||
| FuncGraphToFuncGraphCounterMap* all_parents_direct_; | |||||
| FuncGraphToFuncGraphCounterMap *all_parents_direct_; | |||||
| }; | }; | ||||
| using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | ||||
| @@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; | |||||
| // graph's nearest parent in parents total | // graph's nearest parent in parents total | ||||
| class ParentComputer final : public DepComputer { | class ParentComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~ParentComputer() override = default; | ~ParentComputer() override = default; | ||||
| FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } | |||||
| FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; } | |||||
| size_t size() const override { return parent_analysis_.size(); } | size_t size() const override { return parent_analysis_.size(); } | ||||
| @@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer { | |||||
| // graph's children graph except self | // graph's children graph except self | ||||
| class ChildrenComputer final : public DepComputer { | class ChildrenComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~ChildrenComputer() override = default; | ~ChildrenComputer() override = default; | ||||
| FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } | |||||
| FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; } | |||||
| size_t size() const override { return children_analysis_.size(); } | size_t size() const override { return children_analysis_.size(); } | ||||
| @@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer { | |||||
| // graph's children graph include self | // graph's children graph include self | ||||
| class ScopeComputer final : public DepComputer { | class ScopeComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~ScopeComputer() override = default; | ~ScopeComputer() override = default; | ||||
| FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } | |||||
| FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; } | |||||
| size_t size() const override { return scope_analysis_.size(); } | size_t size() const override { return scope_analysis_.size(); } | ||||
| @@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash | |||||
| class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { | class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { | ||||
| public: | public: | ||||
| explicit FVTotalComputer(const FuncGraphManager* m) | |||||
| explicit FVTotalComputer(const FuncGraphManager *m) | |||||
| : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | ||||
| ~FVTotalComputer() override = default; | ~FVTotalComputer() override = default; | ||||
| FVTotalMap& fv_total_analysis() { return fv_total_analysis_; } | |||||
| FVTotalMap &fv_total_analysis() { return fv_total_analysis_; } | |||||
| size_t size() const override { return fv_total_analysis_.size(); } | size_t size() const override { return fv_total_analysis_.size(); } | ||||
| @@ -453,10 +453,10 @@ class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector | |||||
| class FuncGraphsUsedTotalComputer final : public DepComputer { | class FuncGraphsUsedTotalComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit FuncGraphsUsedTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~FuncGraphsUsedTotalComputer() override = default; | ~FuncGraphsUsedTotalComputer() override = default; | ||||
| FuncGraphToFuncGraphSetMap& func_graph_used_total_analysis() { return func_graph_used_total_analysis_; } | |||||
| FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; } | |||||
| size_t size() const override { return func_graph_used_total_analysis_.size(); } | size_t size() const override { return func_graph_used_total_analysis_.size(); } | ||||
| @@ -473,13 +473,13 @@ using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGrap | |||||
| class RecursiveComputer final : public DepComputer { | class RecursiveComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit RecursiveComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~RecursiveComputer() override = default; | ~RecursiveComputer() override = default; | ||||
| RecursiveMap& recursive_map() { return recursive_map_; } | |||||
| FuncGraphToBoolMap& recursive_analysis() { return recursive_analysis_; } | |||||
| RecursiveMap &recursive_map() { return recursive_map_; } | |||||
| FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; } | |||||
| void CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list<FuncGraphPtr>* trace); | |||||
| void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace); | |||||
| size_t size() const override { return recursive_analysis_.size(); } | size_t size() const override { return recursive_analysis_.size(); } | ||||
| @@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer { | |||||
| class FuncGraphJTotalComputer final : public DepComputer { | class FuncGraphJTotalComputer final : public DepComputer { | ||||
| public: | public: | ||||
| explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} | |||||
| explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||||
| ~FuncGraphJTotalComputer() override = default; | ~FuncGraphJTotalComputer() override = default; | ||||
| FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } | |||||
| FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; } | |||||
| size_t size() const override { return j_total_analysis_.size(); } | size_t size() const override { return j_total_analysis_.size(); } | ||||
| @@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer { | |||||
| void ExtraReset() override { j_total_analysis_.clear(); } | void ExtraReset() override { j_total_analysis_.clear(); } | ||||
| void RealRecompute(FuncGraphPtr fg) override; | void RealRecompute(FuncGraphPtr fg) override; | ||||
| bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); | |||||
| bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); | |||||
| }; | }; | ||||
| class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | ||||
| public: | public: | ||||
| explicit FuncGraphManager(const std::vector<FuncGraphPtr>& roots, bool manage = true); | |||||
| explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true); | |||||
| ~FuncGraphManager() { | ~FuncGraphManager() { | ||||
| if (is_manage_) { | if (is_manage_) { | ||||
| RemoveRoots(); | RemoveRoots(); | ||||
| @@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| void Init(); | void Init(); | ||||
| void Clear(); | void Clear(); | ||||
| void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); | void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); | ||||
| void KeepRoots(const std::vector<FuncGraphPtr>& roots = {}); | |||||
| void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); | |||||
| void RemoveRoots(); | void RemoveRoots(); | ||||
| void SetParameters(const FuncGraphPtr& fg, const std::vector<AnfNodePtr>& parameters); | |||||
| void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); | |||||
| bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); | |||||
| void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); | |||||
| void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); | |||||
| void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters); | |||||
| void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | |||||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | |||||
| void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); | |||||
| FuncGraphTransaction Transact(); | FuncGraphTransaction Transact(); | ||||
| void CommitChanges(const std::vector<Change>& changes); | |||||
| void CommitChanges(const std::vector<Change> &changes); | |||||
| bool IsManaged() const { return is_manage_; } | bool IsManaged() const { return is_manage_; } | ||||
| const FuncGraphSet& roots() const { return roots_; } | |||||
| const FuncGraphSet &roots() const { return roots_; } | |||||
| const FuncGraphSet& func_graphs() const { return func_graphs_; } | |||||
| const FuncGraphSet &func_graphs() const { return func_graphs_; } | |||||
| AnfNodeSet& all_nodes() { return all_nodes_; } | |||||
| AnfNodeSet &all_nodes() { return all_nodes_; } | |||||
| NodeUsersMap& node_users() { return node_users_; } | |||||
| NodeUsersMap &node_users() { return node_users_; } | |||||
| FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } | |||||
| FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } | |||||
| FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } | |||||
| FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } | |||||
| FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||||
| FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } | |||||
| FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { | |||||
| return func_graph_child_direct_->count_func_graphs_map_; | return func_graph_child_direct_->count_func_graphs_map_; | ||||
| } | } | ||||
| FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { | |||||
| return func_graph_parents_direct_->count_func_graphs_map_; | return func_graph_parents_direct_->count_func_graphs_map_; | ||||
| } | } | ||||
| FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||||
| FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } | |||||
| FVTotalMap& free_variables_total() const; | |||||
| FVTotalMap &free_variables_total() const; | |||||
| FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; | |||||
| FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; | |||||
| FuncGraphSet& scopes(const FuncGraphPtr& fg) const; | |||||
| FuncGraphSet &scopes(const FuncGraphPtr &fg) const; | |||||
| FuncGraphPtr parent(const FuncGraphPtr& fg) const; | |||||
| FuncGraphPtr parent(const FuncGraphPtr &fg) const; | |||||
| FuncGraphSet& children(const FuncGraphPtr& fg) const; | |||||
| FuncGraphSet &children(const FuncGraphPtr &fg) const; | |||||
| FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; | |||||
| FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const; | |||||
| bool recursive(const FuncGraphPtr& fg) const; | |||||
| std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr& fg) const; | |||||
| bool recursive(const FuncGraphPtr &fg) const; | |||||
| std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const; | |||||
| bool func_graph_j_total(const FuncGraphPtr& fg) const; | |||||
| bool func_graph_j_total(const FuncGraphPtr &fg) const; | |||||
| std::shared_ptr<Signals> signals() const { return signals_; } | std::shared_ptr<Signals> signals() const { return signals_; } | ||||
| IncludeType Limit(const AnfNodePtr& node); | |||||
| IncludeType Limit(const AnfNodePtr &node); | |||||
| // Static Analysis | // Static Analysis | ||||
| NodeUsersMap node_users_; | NodeUsersMap node_users_; | ||||
| @@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| std::shared_ptr<ParentComputer> func_graph_parent_; | std::shared_ptr<ParentComputer> func_graph_parent_; | ||||
| private: | private: | ||||
| void AddIntoManaged(const FuncGraphPtr& fg); | |||||
| void AddIntoManaged(const FuncGraphPtr &fg); | |||||
| void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); | void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); | ||||
| void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); | |||||
| void AcquireNodes(const std::vector<AnfNodePtr>& nodes); | |||||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr>& nodes); | |||||
| void ParseChanges(const std::vector<Change>& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, | |||||
| Counter<AnfNodePtr>* adds, Counter<AnfNodePtr>* rms); | |||||
| void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); | |||||
| void AcquireNodes(const std::vector<AnfNodePtr> &nodes); | |||||
| FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); | |||||
| void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, | |||||
| Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); | |||||
| FuncGraphSet roots_; // managed roots | FuncGraphSet roots_; // managed roots | ||||
| FuncGraphSet func_graphs_; // managed func graphs | FuncGraphSet func_graphs_; // managed func graphs | ||||
| @@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||||
| class FuncGraphTransaction { | class FuncGraphTransaction { | ||||
| public: | public: | ||||
| explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { | |||||
| explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() { | |||||
| MS_EXCEPTION_IF_NULL(manager_); | MS_EXCEPTION_IF_NULL(manager_); | ||||
| if (!manager_->IsManaged()) { | if (!manager_->IsManaged()) { | ||||
| MS_LOG(DEBUG) << "The manager is not managed yet"; | MS_LOG(DEBUG) << "The manager is not managed yet"; | ||||
| @@ -648,19 +648,19 @@ class FuncGraphTransaction { | |||||
| ~FuncGraphTransaction() { manager_ = nullptr; } | ~FuncGraphTransaction() { manager_ = nullptr; } | ||||
| // set parameters of a func graph | // set parameters of a func graph | ||||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr>& params); | |||||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms); | |||||
| // replace old_node with new_node | // replace old_node with new_node | ||||
| bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); | |||||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||||
| // set esge, i.e., declare setting node.inputs[key] to value. | // set esge, i.e., declare setting node.inputs[key] to value. | ||||
| void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); | |||||
| void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); | |||||
| // commit all changes | // commit all changes | ||||
| void Commit(); | void Commit(); | ||||
| private: | private: | ||||
| FuncGraphManager* manager_; | |||||
| FuncGraphManager *manager_; | |||||
| std::vector<Change> changes_; | std::vector<Change> changes_; | ||||
| }; | }; | ||||
| @@ -668,9 +668,9 @@ class FuncGraphTransaction { | |||||
| struct ArgsOfSetParams { | struct ArgsOfSetParams { | ||||
| FuncGraphPtr func_graph; | FuncGraphPtr func_graph; | ||||
| std::vector<AnfNodePtr> params; | std::vector<AnfNodePtr> params; | ||||
| bool operator==(const ArgsOfSetParams& other) const { return &other == this; } | |||||
| bool operator==(const ArgsOfSetParams &other) const { return &other == this; } | |||||
| friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { | |||||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) { | |||||
| os << "[ArgsOfSetParams]"; | os << "[ArgsOfSetParams]"; | ||||
| return os; | return os; | ||||
| } | } | ||||
| @@ -681,9 +681,9 @@ struct ArgsOfSetEdge { | |||||
| CNodePtr root_node; | CNodePtr root_node; | ||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| size_t index; | size_t index; | ||||
| bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } | |||||
| bool operator==(const ArgsOfSetEdge &other) const { return &other == this; } | |||||
| friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) { | |||||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) { | |||||
| os << "[ArgsOfSetEdge]"; | os << "[ArgsOfSetEdge]"; | ||||
| return os; | return os; | ||||
| } | } | ||||
| @@ -693,7 +693,7 @@ struct Change { | |||||
| enum OpName { kTxSetParams, kTxSetEdge }; | enum OpName { kTxSetParams, kTxSetEdge }; | ||||
| OpName op; | OpName op; | ||||
| Any args; | Any args; | ||||
| Change(OpName name, const Any& para) : op(name), args(para) {} | |||||
| Change(OpName name, const Any ¶) : op(name), args(para) {} | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,25 +42,25 @@ namespace mindspore { | |||||
| // generate a graph corresponding to these types. | // generate a graph corresponding to these types. | ||||
| class MetaFuncGraph : public FuncGraphBase { | class MetaFuncGraph : public FuncGraphBase { | ||||
| public: | public: | ||||
| explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } | |||||
| explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } | |||||
| ~MetaFuncGraph() override = default; | ~MetaFuncGraph() override = default; | ||||
| MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); | MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); | ||||
| abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); | |||||
| abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); | |||||
| // Return normalized versions of the arguments. | // Return normalized versions of the arguments. | ||||
| // By default, this returns args unchanged. | // By default, this returns args unchanged. | ||||
| virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { | |||||
| virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { | |||||
| return args_spec_list; | return args_spec_list; | ||||
| } | } | ||||
| const std::vector<Signature>& signatures() const { return signatures_; } | |||||
| void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; } | |||||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||||
| void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; } | |||||
| // Generate a Graph for the given abstract arguments. | // Generate a Graph for the given abstract arguments. | ||||
| virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { | |||||
| virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { | |||||
| TypePtrList types; | TypePtrList types; | ||||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), | (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), | ||||
| [](const AbstractBasePtr& arg) -> TypePtr { | |||||
| [](const AbstractBasePtr &arg) -> TypePtr { | |||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| return arg->BuildType(); | return arg->BuildType(); | ||||
| }); | }); | ||||
| @@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase { | |||||
| } | } | ||||
| // Generate a Graph for this type signature. | // Generate a Graph for this type signature. | ||||
| virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { | |||||
| virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { | |||||
| MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; | MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; | ||||
| } | } | ||||
| @@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase { | |||||
| std::string ToString() const override { return name_; } | std::string ToString() const override { return name_; } | ||||
| std::size_t hash() const override { return tid(); } | std::size_t hash() const override { return tid(); } | ||||
| virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } | |||||
| bool operator==(const Value& other) const override { | |||||
| virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } | |||||
| bool operator==(const Value &other) const override { | |||||
| if (other.isa<MetaFuncGraph>()) { | if (other.isa<MetaFuncGraph>()) { | ||||
| return &other == this; | return &other == this; | ||||
| } else { | } else { | ||||
| @@ -31,7 +31,7 @@ namespace mindspore { | |||||
| namespace tensor { | namespace tensor { | ||||
| void DataBuf2Contiguous(const py::array& src, py::array* const dest) { | |||||
| void DataBuf2Contiguous(const py::array &src, py::array *const dest) { | |||||
| if (dest == nullptr) { | if (dest == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | ||||
| } | } | ||||
| @@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) { | |||||
| // MetaTensor has default type_id_ which is TypeId::kTypeUnknown. | // MetaTensor has default type_id_ which is TypeId::kTypeUnknown. | ||||
| MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} | MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} | ||||
| MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {} | |||||
| MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {} | |||||
| MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||||
| MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) { | |||||
| TypeId data_type = TypeId::kTypeUnknown; | TypeId data_type = TypeId::kTypeUnknown; | ||||
| if (type_ptr != nullptr) { | if (type_ptr != nullptr) { | ||||
| data_type = type_ptr->type_id(); | data_type = type_ptr->type_id(); | ||||
| @@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||||
| } | } | ||||
| } | } | ||||
| MetaTensor::MetaTensor(const MetaTensor& meta_tensor) | |||||
| MetaTensor::MetaTensor(const MetaTensor &meta_tensor) | |||||
| : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} | : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} | ||||
| MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { | |||||
| MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) { | |||||
| if (&meta_tensor == this) { | if (&meta_tensor == this) { | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { | |||||
| bool MetaTensor::operator==(const MetaTensor &meta_tensor) const { | |||||
| return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); | return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); | ||||
| } | } | ||||
| @@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { | |||||
| return type_ptr; | return type_ptr; | ||||
| } | } | ||||
| void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { | |||||
| void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) { | |||||
| DeviceInfo info(format, data_type); | DeviceInfo info(format, data_type); | ||||
| set_device_info(info); | set_device_info(info); | ||||
| } | } | ||||
| @@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||||
| Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) { | |||||
| TypeId data_type = TypeId::kTypeUnknown; | TypeId data_type = TypeId::kTypeUnknown; | ||||
| if (type_ptr != nullptr) { | if (type_ptr != nullptr) { | ||||
| data_type = type_ptr->type_id(); | data_type = type_ptr->type_id(); | ||||
| @@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { | |||||
| init(data_type_, shape_, &data_); | init(data_type_, shape_, &data_); | ||||
| } | } | ||||
| Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); } | |||||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); } | |||||
| Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } | |||||
| Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); } | |||||
| Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } | |||||
| Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) | |||||
| Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) | |||||
| : MetaTensor(tensor), device_address_(tensor.device_address()) { | : MetaTensor(tensor), device_address_(tensor.device_address()) { | ||||
| init(tensor.data_, data_type); | init(tensor.data_, data_type); | ||||
| } | } | ||||
| Tensor& Tensor::operator=(const Tensor& tensor) { | |||||
| Tensor &Tensor::operator=(const Tensor &tensor) { | |||||
| if (this != &tensor) { | if (this != &tensor) { | ||||
| MetaTensor::operator=(tensor); | MetaTensor::operator=(tensor); | ||||
| dirty_ = tensor.is_dirty(); | dirty_ = tensor.is_dirty(); | ||||
| @@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| bool Tensor::operator==(const Tensor& tensor) const { | |||||
| bool Tensor::operator==(const Tensor &tensor) const { | |||||
| return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | return (MetaTensor::operator==(tensor) && data_ == tensor.data_); | ||||
| } | } | ||||
| bool Tensor::ValueEqualPy(const py::object& other) const { | |||||
| bool Tensor::ValueEqualPy(const py::object &other) const { | |||||
| if (!py::isinstance<Tensor>(other)) { | if (!py::isinstance<Tensor>(other)) { | ||||
| MS_LOG(WARNING) << "compare other not a tensor"; | MS_LOG(WARNING) << "compare other not a tensor"; | ||||
| return false; | return false; | ||||
| @@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { | |||||
| return ValueEqual(py::cast<Tensor>(other)); | return ValueEqual(py::cast<Tensor>(other)); | ||||
| } | } | ||||
| bool Tensor::ValueEqual(const Tensor& other) const { | |||||
| bool Tensor::ValueEqual(const Tensor &other) const { | |||||
| auto equal = [&other, this]() -> bool { | auto equal = [&other, this]() -> bool { | ||||
| auto np = py::module::import("numpy"); | auto np = py::module::import("numpy"); | ||||
| auto equal = np.attr("equal")(data_, other.data_); | auto equal = np.attr("equal")(data_, other.data_); | ||||
| @@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); } | |||||
| std::vector<int> Tensor::shape_c(void) const { return shape(); } | std::vector<int> Tensor::shape_c(void) const { return shape(); } | ||||
| void* Tensor::data_c(bool writable) { | |||||
| void *Tensor::data_c(bool writable) { | |||||
| // operand of bit operation should be unsigned int. | // operand of bit operation should be unsigned int. | ||||
| unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; | unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; | ||||
| bool is_c_contiguous = (flags != 0) ? true : false; | bool is_c_contiguous = (flags != 0) ? true : false; | ||||
| @@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) { | |||||
| return data_.request(writable).ptr; | return data_.request(writable).ptr; | ||||
| } | } | ||||
| TypeId Tensor::GetDataType(const py::buffer_info& buf) const { | |||||
| TypeId Tensor::GetDataType(const py::buffer_info &buf) const { | |||||
| TypeId data_type = TypeId::kTypeUnknown; | TypeId data_type = TypeId::kTypeUnknown; | ||||
| if (buf.format.compare("e") == 0) { | if (buf.format.compare("e") == 0) { | ||||
| data_type = TypeId::kNumberTypeFloat16; | data_type = TypeId::kNumberTypeFloat16; | ||||
| @@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { | |||||
| return data_type; | return data_type; | ||||
| } | } | ||||
| void Tensor::init(const py::array& input, const TypePtr& type_ptr) { | |||||
| void Tensor::init(const py::array &input, const TypePtr &type_ptr) { | |||||
| TypeId data_type = TypeId::kTypeUnknown; | TypeId data_type = TypeId::kTypeUnknown; | ||||
| if (type_ptr != nullptr) { | if (type_ptr != nullptr) { | ||||
| data_type = type_ptr->type_id(); | data_type = type_ptr->type_id(); | ||||
| @@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { | |||||
| init(input, data_type); | init(input, data_type); | ||||
| } | } | ||||
| void Tensor::init(const py::array& input, const TypeId& data_type) { | |||||
| void Tensor::init(const py::array &input, const TypeId &data_type) { | |||||
| py::buffer_info buf = input.request(); | py::buffer_info buf = input.request(); | ||||
| data_type_ = GetDataType(buf); | data_type_ = GetDataType(buf); | ||||
| @@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { | |||||
| } | } | ||||
| } | } | ||||
| void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) { | |||||
| void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) { | |||||
| data_type_ = data_type; | data_type_ = data_type; | ||||
| shape_ = shape; | shape_ = shape; | ||||
| switch (data_type) { | switch (data_type) { | ||||
| @@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { | |||||
| return data_type_; | return data_type_; | ||||
| } | } | ||||
| bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, | |||||
| bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, | |||||
| const TypeId out_data_type) { | const TypeId out_data_type) { | ||||
| if (out == nullptr) { | if (out == nullptr) { | ||||
| return false; | return false; | ||||
| @@ -458,7 +458,7 @@ py::array Tensor::data_sync() { | |||||
| return data_; | return data_; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| // dtype should define before Tensor, because Tensor init depend dtype | // dtype should define before Tensor, because Tensor init depend dtype | ||||
| (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") | (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") | ||||
| .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) | .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) | ||||
| @@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { | |||||
| .def("__repr__", &Tensor::ToStringRepr) | .def("__repr__", &Tensor::ToStringRepr) | ||||
| .def("__eq__", &Tensor::ValueEqualPy) | .def("__eq__", &Tensor::ValueEqualPy) | ||||
| .def(py::pickle( | .def(py::pickle( | ||||
| [](const Tensor& t) { // __getstate__ | |||||
| [](const Tensor &t) { // __getstate__ | |||||
| /* Return a tuple that fully encodes the state of the object */ | /* Return a tuple that fully encodes the state of the object */ | ||||
| return py::make_tuple(t.data()); | return py::make_tuple(t.data()); | ||||
| }, | }, | ||||
| [](const py::tuple& t) { // __setstate__ | |||||
| [](const py::tuple &t) { // __setstate__ | |||||
| if (t.size() != 1) { | if (t.size() != 1) { | ||||
| throw std::runtime_error("Invalid state!"); | throw std::runtime_error("Invalid state!"); | ||||
| } | } | ||||
| @@ -131,16 +131,16 @@ class MetaTensor : public Value { | |||||
| // information of a Tensor. The following codes will create a 2x3 float | // information of a Tensor. The following codes will create a 2x3 float | ||||
| // param data_type The data type of the tensor. | // param data_type The data type of the tensor. | ||||
| // param shape The shape of the tensor. | // param shape The shape of the tensor. | ||||
| MetaTensor(const TypeId data_type, const std::vector<int>& shape); | |||||
| MetaTensor(const TypeId data_type, const std::vector<int> &shape); | |||||
| MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); | |||||
| MetaTensor(const TypePtr &type_ptr, const py::tuple &shape); | |||||
| // brief Constructs a MetaTensor object from an existing MetaTensor instance. | // brief Constructs a MetaTensor object from an existing MetaTensor instance. | ||||
| // | // | ||||
| // The constructed MetaTensor object will have the same data type and shape as the | // The constructed MetaTensor object will have the same data type and shape as the | ||||
| // meta_tensor. | // meta_tensor. | ||||
| // | // | ||||
| // param meta_tensor An existing MetaTensor object. | // param meta_tensor An existing MetaTensor object. | ||||
| MetaTensor(const MetaTensor& meta_tensor); | |||||
| MetaTensor(const MetaTensor &meta_tensor); | |||||
| ~MetaTensor() override = default; | ~MetaTensor() override = default; | ||||
| MS_DECLARE_PARENT(MetaTensor, Value) | MS_DECLARE_PARENT(MetaTensor, Value) | ||||
| @@ -149,7 +149,7 @@ class MetaTensor : public Value { | |||||
| // The constructed MetaTensor object has the same type and shape with meta_tensor. | // The constructed MetaTensor object has the same type and shape with meta_tensor. | ||||
| // | // | ||||
| // param meta_tensor An existing MetaTensor object. | // param meta_tensor An existing MetaTensor object. | ||||
| virtual MetaTensor& operator=(const MetaTensor& meta_tensor); | |||||
| virtual MetaTensor &operator=(const MetaTensor &meta_tensor); | |||||
| // brief Compares two MetaTensor objects. | // brief Compares two MetaTensor objects. | ||||
| // | // | ||||
| @@ -157,7 +157,7 @@ class MetaTensor : public Value { | |||||
| // | // | ||||
| // param meta_tensor The MetaTensor object to be compared. | // param meta_tensor The MetaTensor object to be compared. | ||||
| // return true: If having same type and shape, return true, or return false. | // return true: If having same type and shape, return true, or return false. | ||||
| virtual bool operator==(const MetaTensor& meta_tensor) const; | |||||
| virtual bool operator==(const MetaTensor &meta_tensor) const; | |||||
| // brief Returns the data type of the tensor in its MetaTensor. | // brief Returns the data type of the tensor in its MetaTensor. | ||||
| // | // | ||||
| @@ -193,7 +193,7 @@ class MetaTensor : public Value { | |||||
| // | // | ||||
| // param shape The shape of the tensor. | // param shape The shape of the tensor. | ||||
| // return The shape's size. | // return The shape's size. | ||||
| size_t set_shape(const std::vector<int>& shape) { | |||||
| size_t set_shape(const std::vector<int> &shape) { | |||||
| this->shape_ = shape; | this->shape_ = shape; | ||||
| return shape_.size(); | return shape_.size(); | ||||
| } | } | ||||
| @@ -202,9 +202,9 @@ class MetaTensor : public Value { | |||||
| DeviceInfo device_info() const { return device_info_; } | DeviceInfo device_info() const { return device_info_; } | ||||
| // Set tensor's device info. | // Set tensor's device info. | ||||
| void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } | |||||
| void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } | |||||
| void SetDeviceInfo(const std::string& format, const TypePtr& data_type); | |||||
| void SetDeviceInfo(const std::string &format, const TypePtr &data_type); | |||||
| // Get the size of a given dimension by its index number. | // Get the size of a given dimension by its index number. | ||||
| int DimensionSize(size_t index) const; | int DimensionSize(size_t index) const; | ||||
| @@ -222,9 +222,9 @@ class MetaTensor : public Value { | |||||
| } | } | ||||
| return hash_value; | return hash_value; | ||||
| } | } | ||||
| bool operator==(const Value& other) const override { | |||||
| bool operator==(const Value &other) const override { | |||||
| if (other.isa<MetaTensor>()) { | if (other.isa<MetaTensor>()) { | ||||
| auto other_ = static_cast<const MetaTensor&>(other); | |||||
| auto other_ = static_cast<const MetaTensor &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| @@ -262,49 +262,49 @@ class Tensor : public MetaTensor { | |||||
| // | // | ||||
| // param type_ptr [TypePty] Data type of the tensor. | // param type_ptr [TypePty] Data type of the tensor. | ||||
| // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. | // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. | ||||
| Tensor(const TypePtr& type_ptr, const py::tuple& shape); | |||||
| Tensor(const TypePtr &type_ptr, const py::tuple &shape); | |||||
| // brief Constructor for C++. | // brief Constructor for C++. | ||||
| // | // | ||||
| // param data_type [TypeId] Data type of the tensor. | // param data_type [TypeId] Data type of the tensor. | ||||
| // param shape The shape represented by std::vector<int> of the tensor. | // param shape The shape represented by std::vector<int> of the tensor. | ||||
| Tensor(TypeId data_type, const std::vector<int>& shape); | |||||
| Tensor(TypeId data_type, const std::vector<int> &shape); | |||||
| // brief Constructor for Python. | // brief Constructor for Python. | ||||
| // | // | ||||
| // param input [py::array] Data value of the tensor. | // param input [py::array] Data value of the tensor. | ||||
| // param data_type [TypeId] Data type of the tensor. | // param data_type [TypeId] Data type of the tensor. | ||||
| explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); | |||||
| explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); | |||||
| // brief Constructor | // brief Constructor | ||||
| // | // | ||||
| // param input [py::list] the data for tensor | // param input [py::list] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); | |||||
| explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); | |||||
| // brief Constructor | // brief Constructor | ||||
| // | // | ||||
| // param input [py::tuple] the data for tensor | // param input [py::tuple] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); | |||||
| explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); | |||||
| // brief Constructor | // brief Constructor | ||||
| // | // | ||||
| // param input [py::float_] the data for tensor | // param input [py::float_] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); | |||||
| explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); | |||||
| // brief Constructor | // brief Constructor | ||||
| // | // | ||||
| // param input [py::int_] the data for tensor | // param input [py::int_] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); | |||||
| explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); | |||||
| // brief Constructor | // brief Constructor | ||||
| // | // | ||||
| // param input [Tensor] the data for tensor | // param input [Tensor] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); | |||||
| Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); | |||||
| ~Tensor() override = default; | ~Tensor() override = default; | ||||
| @@ -315,7 +315,7 @@ class Tensor : public MetaTensor { | |||||
| // The constructed Tensor object has the same type and shape with tensor. | // The constructed Tensor object has the same type and shape with tensor. | ||||
| // | // | ||||
| // param tensor An existing Tensor object. | // param tensor An existing Tensor object. | ||||
| Tensor& operator=(const Tensor& tensor); | |||||
| Tensor &operator=(const Tensor &tensor); | |||||
| // brief Compares two Tensor objects. | // brief Compares two Tensor objects. | ||||
| // | // | ||||
| @@ -324,17 +324,17 @@ class Tensor : public MetaTensor { | |||||
| // | // | ||||
| // param tensor The Tensor object to be compared. | // param tensor The Tensor object to be compared. | ||||
| // return true: If having same type, shape and data, return true, or return false. | // return true: If having same type, shape and data, return true, or return false. | ||||
| bool operator==(const Tensor& tensor) const; | |||||
| bool operator==(const Tensor &tensor) const; | |||||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | ||||
| bool ValueEqual(const Tensor& other) const; | |||||
| bool ValueEqual(const Tensor &other) const; | |||||
| // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. | ||||
| bool ValueEqualPy(const py::object& other) const; | |||||
| bool ValueEqualPy(const py::object &other) const; | |||||
| bool operator==(const Value& other) const override { | |||||
| bool operator==(const Value &other) const override { | |||||
| if (other.isa<Tensor>()) { | if (other.isa<Tensor>()) { | ||||
| auto other_ = static_cast<const Tensor&>(other); | |||||
| auto other_ = static_cast<const Tensor &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| @@ -375,13 +375,13 @@ class Tensor : public MetaTensor { | |||||
| // | // | ||||
| // param writable true if writable, false if read only | // param writable true if writable, false if read only | ||||
| // return The pointer to the object | // return The pointer to the object | ||||
| void* data_c(bool writable = false); | |||||
| void *data_c(bool writable = false); | |||||
| // brief Get data type from tensor data. | // brief Get data type from tensor data. | ||||
| // | // | ||||
| // param buf The buffer info of the py::array data. | // param buf The buffer info of the py::array data. | ||||
| // return The [TypeId] of the tensor data. | // return The [TypeId] of the tensor data. | ||||
| TypeId GetDataType(const py::buffer_info& buf) const; | |||||
| TypeId GetDataType(const py::buffer_info &buf) const; | |||||
| // brief Sets the data type of a tensor. | // brief Sets the data type of a tensor. | ||||
| // | // | ||||
| @@ -401,23 +401,23 @@ class Tensor : public MetaTensor { | |||||
| // param input [py::array] the data for tensor | // param input [py::array] the data for tensor | ||||
| // param data_type [TypeId] data type | // param data_type [TypeId] data type | ||||
| // return true if succeed, false if failed. | // return true if succeed, false if failed. | ||||
| void init(const py::array& input, const TypeId& data_type); | |||||
| void init(const py::array& input, const TypePtr& type_ptr); | |||||
| void init(const py::array &input, const TypeId &data_type); | |||||
| void init(const py::array &input, const TypePtr &type_ptr); | |||||
| // brief init tensor attribute | // brief init tensor attribute | ||||
| // | // | ||||
| // param data_type [TypeId] Data type of the tensor. | // param data_type [TypeId] Data type of the tensor. | ||||
| // param shape [py::array] The shape of the tensor. | // param shape [py::array] The shape of the tensor. | ||||
| // return true if succeed, false if failed. | // return true if succeed, false if failed. | ||||
| void init(TypeId data_type, const std::vector<int>& shape, py::array* data); | |||||
| void init(TypeId data_type, const std::vector<int> &shape, py::array *data); | |||||
| bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); | |||||
| bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); | |||||
| public: | public: | ||||
| bool is_dirty() const { return dirty_; } | bool is_dirty() const { return dirty_; } | ||||
| void set_dirty(const bool dirty) { dirty_ = dirty; } | void set_dirty(const bool dirty) { dirty_ = dirty; } | ||||
| DeviceAddressPtr device_address() const { return device_address_; } | DeviceAddressPtr device_address() const { return device_address_; } | ||||
| void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } | |||||
| void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } | |||||
| py::array data_sync(); | py::array data_sync(); | ||||
| private: | private: | ||||
| @@ -18,9 +18,9 @@ | |||||
| #include "pipeline/static_analysis/abstract_value.h" | #include "pipeline/static_analysis/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| bool Named::operator==(const Value& other) const { | |||||
| bool Named::operator==(const Value &other) const { | |||||
| if (other.isa<Named>()) { | if (other.isa<Named>()) { | ||||
| auto other_named = static_cast<const Named&>(other); | |||||
| auto other_named = static_cast<const Named &>(other); | |||||
| return *this == other_named; | return *this == other_named; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| @@ -27,18 +27,18 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class Named : public Value { | class Named : public Value { | ||||
| public: | public: | ||||
| explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } | |||||
| Named(const Named& other) : Value(other) { | |||||
| explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } | |||||
| Named(const Named &other) : Value(other) { | |||||
| this->name_ = other.name_; | this->name_ = other.name_; | ||||
| hash_id_ = std::hash<std::string>{}(other.name_); | hash_id_ = std::hash<std::string>{}(other.name_); | ||||
| } | } | ||||
| ~Named() override = default; | ~Named() override = default; | ||||
| MS_DECLARE_PARENT(Named, Value); | MS_DECLARE_PARENT(Named, Value); | ||||
| const std::string& name() const { return name_; } | |||||
| virtual bool operator==(const Named& other) const { return name_ == other.name(); } | |||||
| bool operator==(const Value& other) const override; | |||||
| Named& operator=(const Named& other) { | |||||
| const std::string &name() const { return name_; } | |||||
| virtual bool operator==(const Named &other) const { return name_ == other.name(); } | |||||
| bool operator==(const Value &other) const override; | |||||
| Named &operator=(const Named &other) { | |||||
| if (&other != this) { | if (&other != this) { | ||||
| this->type_ = other.type_; | this->type_ = other.type_; | ||||
| this->name_ = other.name_; | this->name_ = other.name_; | ||||
| @@ -50,7 +50,7 @@ class Named : public Value { | |||||
| std::size_t Hash() const { return hash_id_; } | std::size_t Hash() const { return hash_id_; } | ||||
| std::size_t hash() const override { return hash_id_; } | std::size_t hash() const override { return hash_id_; } | ||||
| friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { | |||||
| friend std::ostream &operator<<(std::ostream &os, const Named &nmd) { | |||||
| os << nmd.name(); | os << nmd.name(); | ||||
| return os; | return os; | ||||
| } | } | ||||
| @@ -31,7 +31,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| using mindspore::abstract::AbstractFunction; | using mindspore::abstract::AbstractFunction; | ||||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { | |||||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { | |||||
| auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); | auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); | ||||
| return prim_func; | return prim_func; | ||||
| } | } | ||||
| @@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() { | |||||
| return fn; | return fn; | ||||
| } | } | ||||
| bool Primitive::operator==(const Value& other) const { | |||||
| bool Primitive::operator==(const Value &other) const { | |||||
| if (other.isa<Primitive>()) { | if (other.isa<Primitive>()) { | ||||
| auto other_prim = static_cast<const Primitive&>(other); | |||||
| auto other_prim = static_cast<const Primitive &>(other); | |||||
| return *this == other_prim; | return *this == other_prim; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool Primitive::operator==(const Primitive& other) const { | |||||
| bool Primitive::operator==(const Primitive &other) const { | |||||
| if (name() != other.name()) { | if (name() != other.name()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (attrs_.size() != other.attrs_.size()) { | if (attrs_.size() != other.attrs_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr>& item) -> bool { | |||||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||||
| if (item.second == nullptr) { | if (item.second == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const { | |||||
| void Primitive::set_signatures( | void Primitive::set_signatures( | ||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | ||||
| signatures_.clear(); | signatures_.clear(); | ||||
| for (auto& signature : signatures) { | |||||
| for (auto &signature : signatures) { | |||||
| std::string name; | std::string name; | ||||
| SignatureEnumRW rw; | SignatureEnumRW rw; | ||||
| SignatureEnumKind kind; | SignatureEnumKind kind; | ||||
| @@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << "["; | oss << "["; | ||||
| bool is_first = true; | bool is_first = true; | ||||
| for (auto& attr : attrs_) { | |||||
| for (auto &attr : attrs_) { | |||||
| if (is_first) { | if (is_first) { | ||||
| is_first = false; | is_first = false; | ||||
| } else { | } else { | ||||
| @@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const { | |||||
| } | } | ||||
| py::function PrimitivePy::GetBpropFunction() { | py::function PrimitivePy::GetBpropFunction() { | ||||
| static const char* const get_bprop_func_name = "get_bprop"; | |||||
| static const char *const get_bprop_func_name = "get_bprop"; | |||||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | if (py::hasattr(python_obj_, get_bprop_func_name)) { | ||||
| py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); | py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); | ||||
| return fn; | return fn; | ||||
| @@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() { | |||||
| } | } | ||||
| py::function PrimitivePy::GetComputeFunction() { | py::function PrimitivePy::GetComputeFunction() { | ||||
| static const char* const compute_func_name = "vm_impl"; | |||||
| static const char *const compute_func_name = "vm_impl"; | |||||
| if (py::hasattr(python_obj_, compute_func_name)) { | if (py::hasattr(python_obj_, compute_func_name)) { | ||||
| MS_LOG(INFO) << "" << name() << " compute_func_name"; | MS_LOG(INFO) << "" << name() << " compute_func_name"; | ||||
| @@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() { | |||||
| return vm_fn; | return vm_fn; | ||||
| } | } | ||||
| void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { | |||||
| void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||||
| std::string attr_name = name; | std::string attr_name = name; | ||||
| ValuePtr converted_ret = nullptr; | ValuePtr converted_ret = nullptr; | ||||
| if (py::isinstance<py::module>(obj)) { | if (py::isinstance<py::module>(obj)) { | ||||
| @@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { | |||||
| py::dict PrimitivePy::GetAttrDict() { | py::dict PrimitivePy::GetAttrDict() { | ||||
| py::dict attr_dict; | py::dict attr_dict; | ||||
| for (auto& attr : attrs_) { | |||||
| for (auto &attr : attrs_) { | |||||
| attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | ||||
| } | } | ||||
| return attr_dict; | return attr_dict; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | ||||
| .value("unknown", PrimType::kPrimTypeUnknown) | .value("unknown", PrimType::kPrimTypeUnknown) | ||||
| .value("builtin", PrimType::kPrimTypeBuiltIn) | .value("builtin", PrimType::kPrimTypeBuiltIn) | ||||
| @@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { | |||||
| .value("user_custom", PrimType::kPrimTypeUserCustom); | .value("user_custom", PrimType::kPrimTypeUserCustom); | ||||
| (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | ||||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | ||||
| .def(py::init<py::str&, py::object>()) | |||||
| .def(py::init<py::str &, py::object>()) | |||||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | ||||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | ||||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | ||||
| @@ -48,25 +48,25 @@ enum PrimType { | |||||
| class Primitive : public Named { | class Primitive : public Named { | ||||
| public: | public: | ||||
| explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) | |||||
| explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) | |||||
| : Named(name), signatures_(), prim_type_(prim_type) {} | : Named(name), signatures_(), prim_type_(prim_type) {} | ||||
| Primitive(const Primitive& prim) | |||||
| Primitive(const Primitive &prim) | |||||
| : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} | : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} | ||||
| MS_DECLARE_PARENT(Primitive, Named); | MS_DECLARE_PARENT(Primitive, Named); | ||||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); | |||||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||||
| std::string ToString() const override { return name(); } | std::string ToString() const override { return name(); } | ||||
| virtual py::function GetBpropFunction(); | virtual py::function GetBpropFunction(); | ||||
| virtual py::function GetComputeFunction(); | virtual py::function GetComputeFunction(); | ||||
| Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { | |||||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||||
| attrs_[name] = attr; | attrs_[name] = attr; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| Primitive& SetAttrs(const std::unordered_map<std::string, ValuePtr>& attrs) { | |||||
| for (auto& attr : attrs) { | |||||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||||
| for (auto &attr : attrs) { | |||||
| attrs_[attr.first] = attr.second; | attrs_[attr.first] = attr.second; | ||||
| } | } | ||||
| return *this; | return *this; | ||||
| @@ -76,21 +76,21 @@ class Primitive : public Named { | |||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | ||||
| signatures); | signatures); | ||||
| const std::vector<Signature>& signatures() const { return signatures_; } | |||||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||||
| void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } | |||||
| void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } | |||||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||||
| ValuePtr GetAttr(const std::string& attrName) const { | |||||
| ValuePtr GetAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | auto iter = attrs_.find(attrName); | ||||
| return iter == attrs_.cend() ? nullptr : iter->second; | return iter == attrs_.cend() ? nullptr : iter->second; | ||||
| } | } | ||||
| const std::unordered_map<std::string, ValuePtr>& attrs() const { return attrs_; } | |||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | ||||
| bool HasAttr() const { return !attrs_.empty(); } | bool HasAttr() const { return !attrs_.empty(); } | ||||
| bool HasAttr(const std::string& attrName) const { | |||||
| bool HasAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | auto iter = attrs_.find(attrName); | ||||
| return !(iter == attrs_.cend()); | return !(iter == attrs_.cend()); | ||||
| } | } | ||||
| @@ -103,8 +103,8 @@ class Primitive : public Named { | |||||
| PrimType prim_type() const { return prim_type_; } | PrimType prim_type() const { return prim_type_; } | ||||
| std::string instance_name() const { return instance_name_; } | std::string instance_name() const { return instance_name_; } | ||||
| std::string GetAttrsText() const; | std::string GetAttrsText() const; | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Primitive& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Primitive &other) const; | |||||
| ~Primitive() override = default; | ~Primitive() override = default; | ||||
| protected: | protected: | ||||
| @@ -118,18 +118,18 @@ class Primitive : public Named { | |||||
| class PrimitivePy : public Primitive { | class PrimitivePy : public Primitive { | ||||
| public: | public: | ||||
| PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} | |||||
| PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} | |||||
| ~PrimitivePy() override = default; | ~PrimitivePy() override = default; | ||||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | MS_DECLARE_PARENT(PrimitivePy, Primitive); | ||||
| py::function GetBpropFunction() override; | py::function GetBpropFunction() override; | ||||
| py::function GetComputeFunction() override; | py::function GetComputeFunction() override; | ||||
| void AddPyAttr(const py::str& name, const py::object& obj); | |||||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||||
| py::dict GetAttrDict(); | py::dict GetAttrDict(); | ||||
| const bool parse_info_ = true; | const bool parse_info_ = true; | ||||
| const py::object& GetPyObj() const { return python_obj_; } | |||||
| const py::object &GetPyObj() const { return python_obj_; } | |||||
| bool is_tuple_input_ = false; | bool is_tuple_input_ = false; | ||||
| private: | private: | ||||
| @@ -138,13 +138,13 @@ class PrimitivePy : public Primitive { | |||||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | ||||
| inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { | |||||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||||
| os << *p; | os << *p; | ||||
| return os; | return os; | ||||
| } | } | ||||
| struct PrimitiveEqual { | struct PrimitiveEqual { | ||||
| bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { | |||||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||||
| MS_EXCEPTION_IF_NULL(t1); | MS_EXCEPTION_IF_NULL(t1); | ||||
| MS_EXCEPTION_IF_NULL(t2); | MS_EXCEPTION_IF_NULL(t2); | ||||
| return t1->name() == t2->name(); | return t1->name() == t2->name(); | ||||
| @@ -152,7 +152,7 @@ struct PrimitiveEqual { | |||||
| }; | }; | ||||
| struct PrimitiveHasher { | struct PrimitiveHasher { | ||||
| std::size_t operator()(PrimitivePtr const& prim) const { | |||||
| std::size_t operator()(PrimitivePtr const &prim) const { | |||||
| std::size_t hash = std::hash<std::string>()(prim->name()); | std::size_t hash = std::hash<std::string>()(prim->name()); | ||||
| return hash; | return hash; | ||||
| } | } | ||||
| @@ -55,8 +55,8 @@ class BoolImm : public Scalar { | |||||
| bool value() const { return v_; } | bool value() const { return v_; } | ||||
| bool IsZero() override { return v_ == false; } | bool IsZero() override { return v_ == false; } | ||||
| bool IsOne() override { return v_ == true; } | bool IsOne() override { return v_ == true; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const BoolImm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const BoolImm &other) const; | |||||
| std::string ToString() const override { | std::string ToString() const override { | ||||
| if (v_) { | if (v_) { | ||||
| return "true"; | return "true"; | ||||
| @@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool) | |||||
| class IntergerImm : public Scalar { | class IntergerImm : public Scalar { | ||||
| public: | public: | ||||
| IntergerImm() = default; | IntergerImm() = default; | ||||
| explicit IntergerImm(const TypePtr& t) : Scalar(t) {} | |||||
| explicit IntergerImm(const TypePtr &t) : Scalar(t) {} | |||||
| ~IntergerImm() override = default; | ~IntergerImm() override = default; | ||||
| MS_DECLARE_PARENT(IntergerImm, Scalar) | MS_DECLARE_PARENT(IntergerImm, Scalar) | ||||
| }; | }; | ||||
| @@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| int8_t value() const { return v_; } | int8_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Int8Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Int8Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| int16_t value() const { return v_; } | int16_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Int16Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Int16Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| int32_t value() const { return v_; } | int32_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Int32Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Int32Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| int64_t value() const { return v_; } | int64_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Int64Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Int64Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| uint8_t value() const { return v_; } | uint8_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const UInt8Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const UInt8Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| uint16_t value() const { return v_; } | uint16_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const UInt16Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const UInt16Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| uint32_t value() const { return v_; } | uint32_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const UInt32Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const UInt32Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm { | |||||
| bool IsZero() override { return v_ == 0; } | bool IsZero() override { return v_ == 0; } | ||||
| bool IsOne() override { return v_ == 1; } | bool IsOne() override { return v_ == 1; } | ||||
| uint64_t value() const { return v_; } | uint64_t value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const UInt64Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const UInt64Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t); | |||||
| class FloatImm : public Scalar { | class FloatImm : public Scalar { | ||||
| public: | public: | ||||
| FloatImm() = default; | FloatImm() = default; | ||||
| explicit FloatImm(const TypePtr& t) : Scalar(t) {} | |||||
| explicit FloatImm(const TypePtr &t) : Scalar(t) {} | |||||
| ~FloatImm() override = default; | ~FloatImm() override = default; | ||||
| MS_DECLARE_PARENT(FloatImm, Scalar) | MS_DECLARE_PARENT(FloatImm, Scalar) | ||||
| }; | }; | ||||
| @@ -312,8 +312,8 @@ class FP32Imm : public FloatImm { | |||||
| bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } | bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } | ||||
| bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } | bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } | ||||
| float value() const { return v_; } | float value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const FP32Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const FP32Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -338,8 +338,8 @@ class FP64Imm : public FloatImm { | |||||
| bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } | bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } | ||||
| bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } | bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } | ||||
| double value() const { return v_; } | double value() const { return v_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const FP64Imm& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const FP64Imm &other) const; | |||||
| std::string ToString() const override { return std::to_string(v_); } | std::string ToString() const override { return std::to_string(v_); } | ||||
| std::string DumpText() const override { | std::string DumpText() const override { | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, | |||||
| const py::object& arg_default, const SignatureEnumDType& arg_dtype) | |||||
| Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, | |||||
| const py::object &arg_default, const SignatureEnumDType &arg_dtype) | |||||
| : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { | : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { | ||||
| if (py::isinstance<SignatureEnumKind>(arg_default) && | if (py::isinstance<SignatureEnumKind>(arg_default) && | ||||
| py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { | py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { | ||||
| @@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, | |||||
| } | } | ||||
| } | } | ||||
| Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) | |||||
| Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) | |||||
| : name(arg_name), | : name(arg_name), | ||||
| rw(rw_tag), | rw(rw_tag), | ||||
| kind(arg_kind), | kind(arg_kind), | ||||
| default_value(nullptr), | default_value(nullptr), | ||||
| dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} | dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} | ||||
| REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { | |||||
| (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) | (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) | ||||
| .value("RW_READ", SignatureEnumRW::kRWRead) | .value("RW_READ", SignatureEnumRW::kRWRead) | ||||
| .value("RW_WRITE", SignatureEnumRW::kRWWrite) | .value("RW_WRITE", SignatureEnumRW::kRWWrite) | ||||
| @@ -61,9 +61,9 @@ struct Signature { | |||||
| SignatureEnumKind kind; | SignatureEnumKind kind; | ||||
| ValuePtr default_value; // nullptr for no default value | ValuePtr default_value; // nullptr for no default value | ||||
| SignatureEnumDType dtype; | SignatureEnumDType dtype; | ||||
| Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, | |||||
| const py::object& arg_default, const SignatureEnumDType& arg_dtype); | |||||
| Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind); | |||||
| Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, | |||||
| const py::object &arg_default, const SignatureEnumDType &arg_dtype); | |||||
| Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "pipeline/static_analysis/abstract_value.h" | #include "pipeline/static_analysis/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { | |||||
| const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { | |||||
| if (dim >= size()) { | if (dim >= size()) { | ||||
| MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; | MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; | ||||
| } | } | ||||
| @@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) { | |||||
| } | } | ||||
| } | } | ||||
| bool BoolImm::operator==(const Value& other) const { | |||||
| bool BoolImm::operator==(const Value &other) const { | |||||
| if (other.isa<BoolImm>()) { | if (other.isa<BoolImm>()) { | ||||
| auto other_ = static_cast<const BoolImm&>(other); | |||||
| auto other_ = static_cast<const BoolImm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } | |||||
| bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; } | |||||
| bool Int8Imm::operator==(const Value& other) const { | |||||
| bool Int8Imm::operator==(const Value &other) const { | |||||
| if (other.isa<Int8Imm>()) { | if (other.isa<Int8Imm>()) { | ||||
| auto other_ = static_cast<const Int8Imm&>(other); | |||||
| auto other_ = static_cast<const Int8Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } | |||||
| bool Int16Imm::operator==(const Value& other) const { | |||||
| bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; } | |||||
| bool Int16Imm::operator==(const Value &other) const { | |||||
| if (other.isa<Int16Imm>()) { | if (other.isa<Int16Imm>()) { | ||||
| auto other_ = static_cast<const Int16Imm&>(other); | |||||
| auto other_ = static_cast<const Int16Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } | |||||
| bool Int32Imm::operator==(const Value& other) const { | |||||
| bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; } | |||||
| bool Int32Imm::operator==(const Value &other) const { | |||||
| if (other.isa<Int32Imm>()) { | if (other.isa<Int32Imm>()) { | ||||
| auto other_ = static_cast<const Int32Imm&>(other); | |||||
| auto other_ = static_cast<const Int32Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } | |||||
| bool Int64Imm::operator==(const Value& other) const { | |||||
| bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; } | |||||
| bool Int64Imm::operator==(const Value &other) const { | |||||
| if (other.isa<Int64Imm>()) { | if (other.isa<Int64Imm>()) { | ||||
| auto other_ = static_cast<const Int64Imm&>(other); | |||||
| auto other_ = static_cast<const Int64Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } | |||||
| bool UInt8Imm::operator==(const Value& other) const { | |||||
| bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; } | |||||
| bool UInt8Imm::operator==(const Value &other) const { | |||||
| if (other.isa<UInt8Imm>()) { | if (other.isa<UInt8Imm>()) { | ||||
| auto other_ = static_cast<const UInt8Imm&>(other); | |||||
| auto other_ = static_cast<const UInt8Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } | |||||
| bool UInt16Imm::operator==(const Value& other) const { | |||||
| bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; } | |||||
| bool UInt16Imm::operator==(const Value &other) const { | |||||
| if (other.isa<UInt16Imm>()) { | if (other.isa<UInt16Imm>()) { | ||||
| auto other_ = static_cast<const UInt16Imm&>(other); | |||||
| auto other_ = static_cast<const UInt16Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } | |||||
| bool UInt32Imm::operator==(const Value& other) const { | |||||
| bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; } | |||||
| bool UInt32Imm::operator==(const Value &other) const { | |||||
| if (other.isa<UInt32Imm>()) { | if (other.isa<UInt32Imm>()) { | ||||
| auto other_ = static_cast<const UInt32Imm&>(other); | |||||
| auto other_ = static_cast<const UInt32Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } | |||||
| bool UInt64Imm::operator==(const Value& other) const { | |||||
| bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; } | |||||
| bool UInt64Imm::operator==(const Value &other) const { | |||||
| if (other.isa<UInt64Imm>()) { | if (other.isa<UInt64Imm>()) { | ||||
| auto other_ = static_cast<const UInt64Imm&>(other); | |||||
| auto other_ = static_cast<const UInt64Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } | |||||
| bool FP32Imm::operator==(const Value& other) const { | |||||
| bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; } | |||||
| bool FP32Imm::operator==(const Value &other) const { | |||||
| if (other.isa<FP32Imm>()) { | if (other.isa<FP32Imm>()) { | ||||
| auto other_ = static_cast<const FP32Imm&>(other); | |||||
| auto other_ = static_cast<const FP32Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } | |||||
| bool FP64Imm::operator==(const Value& other) const { | |||||
| bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } | |||||
| bool FP64Imm::operator==(const Value &other) const { | |||||
| if (other.isa<FP64Imm>()) { | if (other.isa<FP64Imm>()) { | ||||
| auto other_ = static_cast<const FP64Imm&>(other); | |||||
| auto other_ = static_cast<const FP64Imm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool ValueSequeue::operator==(const Value& other) const { | |||||
| bool ValueSequeue::operator==(const Value &other) const { | |||||
| if (other.isa<ValueSequeue>()) { | if (other.isa<ValueSequeue>()) { | ||||
| auto other_ = static_cast<const ValueSequeue&>(other); | |||||
| auto other_ = static_cast<const ValueSequeue &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool ValueSequeue::operator==(const ValueSequeue& other) const { | |||||
| bool ValueSequeue::operator==(const ValueSequeue &other) const { | |||||
| if (other.elements_.size() != elements_.size()) { | if (other.elements_.size() != elements_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), | return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), | ||||
| [](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); | |||||
| [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; }); | |||||
| } | } | ||||
| std::string ValueSequeue::ToString() const { | std::string ValueSequeue::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| bool begin = true; | bool begin = true; | ||||
| for (auto& attr : elements_) { | |||||
| for (auto &attr : elements_) { | |||||
| if (!begin) { | if (!begin) { | ||||
| buffer << ", "; | buffer << ", "; | ||||
| } else { | } else { | ||||
| @@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const { | |||||
| return oss.str(); | return oss.str(); | ||||
| } | } | ||||
| bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } | |||||
| bool StringImm::operator==(const Value& other) const { | |||||
| bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } | |||||
| bool StringImm::operator==(const Value &other) const { | |||||
| if (other.isa<StringImm>()) { | if (other.isa<StringImm>()) { | ||||
| auto other_ = static_cast<const StringImm&>(other); | |||||
| auto other_ = static_cast<const StringImm &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } | |||||
| bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } | |||||
| bool RefKey::operator==(const Value& other) const { | |||||
| bool RefKey::operator==(const Value &other) const { | |||||
| if (other.isa<RefKey>()) { | if (other.isa<RefKey>()) { | ||||
| auto other_ = static_cast<const RefKey&>(other); | |||||
| auto other_ = static_cast<const RefKey &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } | |||||
| bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } | |||||
| bool AnyValue::operator==(const Value& other) const { | |||||
| bool AnyValue::operator==(const Value &other) const { | |||||
| if (other.isa<AnyValue>()) { | if (other.isa<AnyValue>()) { | ||||
| return true; | return true; | ||||
| } else { | } else { | ||||
| @@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstr | |||||
| abstract::AbstractBasePtr ValueTuple::ToAbstract() { | abstract::AbstractBasePtr ValueTuple::ToAbstract() { | ||||
| abstract::AbstractBasePtrList a_list; | abstract::AbstractBasePtrList a_list; | ||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { | |||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||||
| MS_EXCEPTION_IF_NULL(ele); | MS_EXCEPTION_IF_NULL(ele); | ||||
| return ele->ToAbstract(); | return ele->ToAbstract(); | ||||
| }); | }); | ||||
| @@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() { | |||||
| abstract::AbstractBasePtr ValueList::ToAbstract() { | abstract::AbstractBasePtr ValueList::ToAbstract() { | ||||
| abstract::AbstractBasePtrList a_list; | abstract::AbstractBasePtrList a_list; | ||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { | |||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||||
| MS_EXCEPTION_IF_NULL(ele); | MS_EXCEPTION_IF_NULL(ele); | ||||
| return ele->ToAbstract(); | return ele->ToAbstract(); | ||||
| }); | }); | ||||
| @@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const { | |||||
| return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); | return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); | ||||
| } | } | ||||
| bool ValueSlice::operator==(const Value& other) const { | |||||
| bool ValueSlice::operator==(const Value &other) const { | |||||
| if (other.isa<ValueSlice>()) { | if (other.isa<ValueSlice>()) { | ||||
| auto other_ = static_cast<const ValueSlice&>(other); | |||||
| auto other_ = static_cast<const ValueSlice &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool ValueSlice::operator==(const ValueSlice& other) const { | |||||
| bool ValueSlice::operator==(const ValueSlice &other) const { | |||||
| MS_EXCEPTION_IF_NULL(start_); | MS_EXCEPTION_IF_NULL(start_); | ||||
| MS_EXCEPTION_IF_NULL(stop_); | MS_EXCEPTION_IF_NULL(stop_); | ||||
| MS_EXCEPTION_IF_NULL(step_); | MS_EXCEPTION_IF_NULL(step_); | ||||
| @@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const { | |||||
| return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()}); | return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()}); | ||||
| } | } | ||||
| bool KeywordArg::operator==(const Value& other) const { | |||||
| bool KeywordArg::operator==(const Value &other) const { | |||||
| if (other.isa<KeywordArg>()) { | if (other.isa<KeywordArg>()) { | ||||
| auto other_ = static_cast<const KeywordArg&>(other); | |||||
| auto other_ = static_cast<const KeywordArg &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } | |||||
| bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); } | |||||
| std::string KeywordArg::ToString() const { | std::string KeywordArg::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| @@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() { | |||||
| return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); | return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); | ||||
| } | } | ||||
| const ValuePtr ValueDictionary::operator[](const std::string& key) const { | |||||
| const ValuePtr ValueDictionary::operator[](const std::string &key) const { | |||||
| auto it = std::find_if(key_values_.begin(), key_values_.end(), | auto it = std::find_if(key_values_.begin(), key_values_.end(), | ||||
| [key](const std::pair<std::string, ValuePtr>& item) { return item.first == key; }); | |||||
| [key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; }); | |||||
| if (it == key_values_.end()) { | if (it == key_values_.end()) { | ||||
| MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; | MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; | ||||
| } | } | ||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| bool ValueDictionary::operator==(const Value& other) const { | |||||
| bool ValueDictionary::operator==(const Value &other) const { | |||||
| if (other.isa<ValueDictionary>()) { | if (other.isa<ValueDictionary>()) { | ||||
| auto other_ = static_cast<const ValueDictionary&>(other); | |||||
| auto other_ = static_cast<const ValueDictionary &>(other); | |||||
| return *this == other_; | return *this == other_; | ||||
| } else { | } else { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool ValueDictionary::operator==(const ValueDictionary& other) const { | |||||
| bool ValueDictionary::operator==(const ValueDictionary &other) const { | |||||
| if (key_values_.size() != other.key_values_.size()) { | if (key_values_.size() != other.key_values_.size()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { | |||||
| std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; | std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; | ||||
| (void)std::transform( | (void)std::transform( | ||||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | key_values_.begin(), key_values_.end(), std::back_inserter(kv), | ||||
| [](const std::pair<std::string, ValuePtr>& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||||
| [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||||
| return std::make_shared<abstract::AbstractDictionary>(kv); | return std::make_shared<abstract::AbstractDictionary>(kv); | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE( | REGISTER_PYBIND_DEFINE( | ||||
| RefKey, ([](const py::module* m) { | |||||
| RefKey, ([](const py::module *m) { | |||||
| (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); | (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); | ||||
| })); | })); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,19 +35,19 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ValueSequeue : public Value { | class ValueSequeue : public Value { | ||||
| public: | public: | ||||
| explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { | |||||
| explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { | |||||
| TypePtrList t_list; | TypePtrList t_list; | ||||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { | |||||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { | |||||
| MS_EXCEPTION_IF_NULL(ele); | MS_EXCEPTION_IF_NULL(ele); | ||||
| return ele->type(); | return ele->type(); | ||||
| }); | }); | ||||
| TypePtr t = std::make_shared<Tuple>(t_list); | TypePtr t = std::make_shared<Tuple>(t_list); | ||||
| type_ = t; | type_ = t; | ||||
| } | } | ||||
| ValueSequeue(const std::initializer_list<ValuePtr>& elements) : elements_(elements.begin(), elements.end()) { | |||||
| ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) { | |||||
| TypePtrList t_list; | TypePtrList t_list; | ||||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), | (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), | ||||
| [](const ValuePtr& ele) { return ele->type(); }); | |||||
| [](const ValuePtr &ele) { return ele->type(); }); | |||||
| TypePtr t = std::make_shared<Tuple>(t_list); | TypePtr t = std::make_shared<Tuple>(t_list); | ||||
| type_ = t; | type_ = t; | ||||
| } | } | ||||
| @@ -56,10 +56,10 @@ class ValueSequeue : public Value { | |||||
| std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); } | std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); } | ||||
| std::size_t size() const { return elements_.size(); } | std::size_t size() const { return elements_.size(); } | ||||
| bool erase(size_t idx); | bool erase(size_t idx); | ||||
| const ValuePtr operator[](const std::size_t& dim) const; | |||||
| const ValuePtrList& value() const { return elements_; } | |||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const ValueSequeue& other) const; | |||||
| const ValuePtr operator[](const std::size_t &dim) const; | |||||
| const ValuePtrList &value() const { return elements_; } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const ValueSequeue &other) const; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| std::string DumpText() const override; | std::string DumpText() const override; | ||||
| @@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr<ValueSequeue>; | |||||
| class ValueTuple : public ValueSequeue { | class ValueTuple : public ValueSequeue { | ||||
| public: | public: | ||||
| explicit ValueTuple(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} | |||||
| ValueTuple(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} | |||||
| explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {} | |||||
| ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {} | |||||
| ~ValueTuple() override = default; | ~ValueTuple() override = default; | ||||
| MS_DECLARE_PARENT(ValueTuple, ValueSequeue) | MS_DECLARE_PARENT(ValueTuple, ValueSequeue) | ||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| @@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr<ValueTuple>; | |||||
| class ValueList : public ValueSequeue { | class ValueList : public ValueSequeue { | ||||
| public: | public: | ||||
| explicit ValueList(const std::vector<ValuePtr>& elements) : ValueSequeue(elements) {} | |||||
| ValueList(const std::initializer_list<ValuePtr>& elements) : ValueSequeue(elements) {} | |||||
| explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {} | |||||
| ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {} | |||||
| ~ValueList() override = default; | ~ValueList() override = default; | ||||
| MS_DECLARE_PARENT(ValueList, ValueSequeue) | MS_DECLARE_PARENT(ValueList, ValueSequeue) | ||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| @@ -94,7 +94,7 @@ class ValueList : public ValueSequeue { | |||||
| }; | }; | ||||
| using ValueListPtr = std::shared_ptr<ValueList>; | using ValueListPtr = std::shared_ptr<ValueList>; | ||||
| inline ValuePtr MakeValue(const std::vector<ValuePtr>& v) { return std::make_shared<ValueTuple>(v); } | |||||
| inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); } | |||||
| inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); } | inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); } | ||||
| template <typename T> | template <typename T> | ||||
| @@ -103,7 +103,7 @@ template <typename T, typename A> | |||||
| struct is_vector<std::vector<T, A>> : public std::true_type {}; | struct is_vector<std::vector<T, A>> : public std::true_type {}; | ||||
| template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type> | template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type> | ||||
| ValuePtr MakeValue(const T& vec) { | |||||
| ValuePtr MakeValue(const T &vec) { | |||||
| std::vector<ValuePtr> list; | std::vector<ValuePtr> list; | ||||
| (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); | (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); | ||||
| return std::make_shared<ValueTuple>(list); | return std::make_shared<ValueTuple>(list); | ||||
| @@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) { | |||||
| class ValueSlice : public Value { | class ValueSlice : public Value { | ||||
| public: | public: | ||||
| ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) | |||||
| ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) | |||||
| : start_(start), stop_(stop), step_(step) {} | : start_(start), stop_(stop), step_(step) {} | ||||
| ~ValueSlice() override = default; | ~ValueSlice() override = default; | ||||
| MS_DECLARE_PARENT(ValueSlice, Value) | MS_DECLARE_PARENT(ValueSlice, Value) | ||||
| std::size_t hash() const override; | std::size_t hash() const override; | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const ValueSlice& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const ValueSlice &other) const; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| @@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr<ValueSlice>; | |||||
| class KeywordArg : public Value { | class KeywordArg : public Value { | ||||
| public: | public: | ||||
| KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} | |||||
| KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} | |||||
| ~KeywordArg() override = default; | ~KeywordArg() override = default; | ||||
| MS_DECLARE_PARENT(KeywordArg, Value) | MS_DECLARE_PARENT(KeywordArg, Value) | ||||
| std::size_t hash() const override; | std::size_t hash() const override; | ||||
| ValuePtr get_value() const { return value_; } | ValuePtr get_value() const { return value_; } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const KeywordArg& other) const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const KeywordArg &other) const; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| @@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr<KeywordArg>; | |||||
| class ValueDictionary : public Value { | class ValueDictionary : public Value { | ||||
| public: | public: | ||||
| explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>>& key_values) : key_values_(key_values) {} | |||||
| explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {} | |||||
| ~ValueDictionary() override = default; | ~ValueDictionary() override = default; | ||||
| MS_DECLARE_PARENT(ValueDictionary, Value) | MS_DECLARE_PARENT(ValueDictionary, Value) | ||||
| std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); } | std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); } | ||||
| std::size_t size() const { return key_values_.size(); } | std::size_t size() const { return key_values_.size(); } | ||||
| const ValuePtr operator[](const std::string& key) const; | |||||
| const std::vector<std::pair<std::string, ValuePtr>>& value() const { return key_values_; } | |||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const ValueDictionary& other) const; | |||||
| const ValuePtr operator[](const std::string &key) const; | |||||
| const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const ValueDictionary &other) const; | |||||
| std::string ToString() const override { | std::string ToString() const override { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| std::vector<std::string> keys; | std::vector<std::string> keys; | ||||
| std::vector<ValuePtr> values; | std::vector<ValuePtr> values; | ||||
| for (const auto& kv : key_values_) { | |||||
| for (const auto &kv : key_values_) { | |||||
| keys.push_back(kv.first); | keys.push_back(kv.first); | ||||
| values.push_back(kv.second); | values.push_back(kv.second); | ||||
| } | } | ||||
| buffer << "(Dict: " | buffer << "(Dict: " | ||||
| << " keys:("; | << " keys:("; | ||||
| for (const auto& key : keys) { | |||||
| for (const auto &key : keys) { | |||||
| buffer << key << ", "; | buffer << key << ", "; | ||||
| } | } | ||||
| buffer << ") values:("; | buffer << ") values:("; | ||||
| for (const auto& value : values) { | |||||
| for (const auto &value : values) { | |||||
| MS_EXCEPTION_IF_NULL(value); | MS_EXCEPTION_IF_NULL(value); | ||||
| buffer << value->DumpText() << ", "; | buffer << value->DumpText() << ", "; | ||||
| } | } | ||||
| @@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>; | |||||
| class StringImm : public Value { | class StringImm : public Value { | ||||
| public: | public: | ||||
| explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {} | |||||
| explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {} | |||||
| ~StringImm() override = default; | ~StringImm() override = default; | ||||
| MS_DECLARE_PARENT(StringImm, Value) | MS_DECLARE_PARENT(StringImm, Value) | ||||
| std::size_t hash() const override { return hash_; } | std::size_t hash() const override { return hash_; } | ||||
| const std::string& value() const { return str_; } | |||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const StringImm& other) const; | |||||
| const std::string &value() const { return str_; } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const StringImm &other) const; | |||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| std::string ToString() const override { return str_; } | std::string ToString() const override { return str_; } | ||||
| @@ -218,18 +218,18 @@ class StringImm : public Value { | |||||
| }; | }; | ||||
| using StringImmPtr = std::shared_ptr<StringImm>; | using StringImmPtr = std::shared_ptr<StringImm>; | ||||
| IMM_TRAITS(StringImmPtr, std::string) | IMM_TRAITS(StringImmPtr, std::string) | ||||
| IMM_TRAITS(StringImmPtr, const char*) | |||||
| IMM_TRAITS(StringImmPtr, const char *) | |||||
| class RefKey : public Value { | class RefKey : public Value { | ||||
| public: | public: | ||||
| explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||||
| explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {} | |||||
| ~RefKey() override = default; | ~RefKey() override = default; | ||||
| MS_DECLARE_PARENT(RefKey, Value) | MS_DECLARE_PARENT(RefKey, Value) | ||||
| std::size_t hash() const override { return hash_; } | std::size_t hash() const override { return hash_; } | ||||
| const std::string& tag() const { return tag_; } | |||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const RefKey& other) const; | |||||
| const std::string &tag() const { return tag_; } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const RefKey &other) const; | |||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| std::string ToString() const override { return "RefKey[" + tag_ + "]"; } | std::string ToString() const override { return "RefKey[" + tag_ + "]"; } | ||||
| @@ -251,13 +251,13 @@ class AnyValue : public Value { | |||||
| ~AnyValue() override = default; | ~AnyValue() override = default; | ||||
| MS_DECLARE_PARENT(AnyValue, Value) | MS_DECLARE_PARENT(AnyValue, Value) | ||||
| std::size_t hash() const override { return tid(); } | std::size_t hash() const override { return tid(); } | ||||
| bool operator==(const Value& other) const override; | |||||
| bool operator==(const Value &other) const override; | |||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| }; | }; | ||||
| extern const ValuePtr kAnyValue; | extern const ValuePtr kAnyValue; | ||||
| template <> | template <> | ||||
| inline const char* GetValue(const ValuePtr& value) { | |||||
| inline const char *GetValue(const ValuePtr &value) { | |||||
| if (value == nullptr) { | if (value == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Value is nullptr"; | MS_LOG(EXCEPTION) << "Value is nullptr"; | ||||
| } | } | ||||
| @@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) { | |||||
| template <typename T, typename S = typename std::decay<T>::type, | template <typename T, typename S = typename std::decay<T>::type, | ||||
| typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type> | typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type> | ||||
| std::vector<U> GetValue(const ValuePtr& value) { | |||||
| std::vector<U> GetValue(const ValuePtr &value) { | |||||
| if (value == nullptr) { | if (value == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Value is nullptr"; | MS_LOG(EXCEPTION) << "Value is nullptr"; | ||||
| } | } | ||||
| @@ -280,21 +280,21 @@ std::vector<U> GetValue(const ValuePtr& value) { | |||||
| << ">"; | << ">"; | ||||
| } | } | ||||
| std::vector<U> rets; | std::vector<U> rets; | ||||
| const std::vector<ValuePtr>& vals = value->cast<ValueSequeuePtr>()->value(); | |||||
| const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value(); | |||||
| (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), | (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), | ||||
| [](const ValuePtr& v) { return GetValue<U>(v); }); | |||||
| [](const ValuePtr &v) { return GetValue<U>(v); }); | |||||
| return rets; | return rets; | ||||
| } | } | ||||
| inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared<ValueNode>(t); } | |||||
| inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); } | |||||
| template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> | template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type> | ||||
| inline ValueNodePtr NewValueNode(const std::shared_ptr<T>& x) { | |||||
| inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) { | |||||
| return NewValueNode(MakeValue(x)); | return NewValueNode(MakeValue(x)); | ||||
| } | } | ||||
| template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> | template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type> | ||||
| inline ValueNodePtr NewValueNode(const T& x) { | |||||
| inline ValueNodePtr NewValueNode(const T &x) { | |||||
| return NewValueNode(MakeValue(x)); | return NewValueNode(MakeValue(x)); | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,15 +22,15 @@ | |||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| using VisitFuncType = std::function<void(const AnfNodePtr&)>; | |||||
| using VisitFuncType = std::function<void(const AnfNodePtr &)>; | |||||
| class AnfVisitor { | class AnfVisitor { | ||||
| public: | public: | ||||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&); | |||||
| virtual void Visit(const AnfNodePtr&); | |||||
| virtual void Visit(const CNodePtr&); | |||||
| virtual void Visit(const ValueNodePtr&); | |||||
| virtual void Visit(const ParameterPtr&); | |||||
| VisitFuncType Match(const PrimitivePtr&, const std::vector<opt::PredicateFuncType>& = {}); | |||||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &); | |||||
| virtual void Visit(const AnfNodePtr &); | |||||
| virtual void Visit(const CNodePtr &); | |||||
| virtual void Visit(const ValueNodePtr &); | |||||
| virtual void Visit(const ParameterPtr &); | |||||
| VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {}); | |||||
| virtual ~AnfVisitor() = default; | virtual ~AnfVisitor() = default; | ||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,12 +26,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| namespace { | namespace { | ||||
| void FilterInvaildKernelInfo(const CNodePtr& kernel_node, | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { | |||||
| void FilterInvaildKernelInfo(const CNodePtr &kernel_node, | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list; | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list; | ||||
| (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), | (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), | ||||
| [&](const std::shared_ptr<kernel::KernelBuildInfo>& kernel_build_info) { | |||||
| [&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) { | |||||
| return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && | ||||
| AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); | ||||
| }); | }); | ||||
| @@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node, | |||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void KernelQuery(const CNodePtr& kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>>* kernel_info_list) { | |||||
| void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| TbeMetadataInfo(kernel_node, kernel_info_list); | TbeMetadataInfo(kernel_node, kernel_info_list); | ||||
| @@ -38,11 +38,11 @@ class OpAttr { | |||||
| std::string value() const { return value_; } | std::string value() const { return value_; } | ||||
| std::string default_value() const { return default_value_; } | std::string default_value() const { return default_value_; } | ||||
| void set_name(const std::string& name) { name_ = name; } | |||||
| void set_param_type(const std::string& param_type) { param_type_ = param_type; } | |||||
| void set_type(const std::string& type) { type_ = type; } | |||||
| void set_value(const std::string& value) { value_ = value; } | |||||
| void set_default_value(const std::string& default_value) { default_value_ = default_value; } | |||||
| void set_name(const std::string &name) { name_ = name; } | |||||
| void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } | |||||
| void set_type(const std::string &type) { type_ = type; } | |||||
| void set_value(const std::string &value) { value_ = value; } | |||||
| void set_default_value(const std::string &default_value) { default_value_ = default_value; } | |||||
| private: | private: | ||||
| std::string name_; | std::string name_; | ||||
| @@ -67,13 +67,13 @@ class OpIOInfo { | |||||
| std::vector<std::string> formats() const { return formats_; } | std::vector<std::string> formats() const { return formats_; } | ||||
| void set_index(const int index) { index_ = index; } | void set_index(const int index) { index_ = index; } | ||||
| void set_name(const std::string& name) { name_ = name; } | |||||
| void set_name(const std::string &name) { name_ = name; } | |||||
| void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } | void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } | ||||
| void set_param_type(const std::string& param_type) { param_type_ = param_type; } | |||||
| void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } | |||||
| void set_shape(const std::string& shape) { shape_ = shape; } | |||||
| void set_dtypes(const std::vector<std::string>& dtype) { dtypes_ = dtype; } | |||||
| void set_formats(const std::vector<std::string>& formats) { formats_ = formats; } | |||||
| void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } | |||||
| void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } | |||||
| void set_shape(const std::string &shape) { shape_ = shape; } | |||||
| void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; } | |||||
| void set_formats(const std::vector<std::string> &formats) { formats_ = formats; } | |||||
| private: | private: | ||||
| int index_ = 0; | int index_ = 0; | ||||
| @@ -104,24 +104,24 @@ class OpInfo { | |||||
| std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } | ||||
| std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } | ||||
| const std::unordered_map<size_t, size_t>& ref_infos() const { return ref_infos_; } | |||||
| const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; } | |||||
| void set_op_name(const std::string& op_name) { op_name_ = op_name; } | |||||
| void set_op_name(const std::string &op_name) { op_name_ = op_name; } | |||||
| void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } | void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } | ||||
| void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } | |||||
| void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } | |||||
| void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } | |||||
| void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } | |||||
| void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } | void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } | ||||
| void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } | |||||
| void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } | |||||
| void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } | void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } | ||||
| void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } | |||||
| void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } | |||||
| void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } | ||||
| void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } | ||||
| void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } | ||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr>& attr) { attrs_ptr_.push_back(attr); } | |||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo>& input) { inputs_ptr_.push_back(input); } | |||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo>& output) { outputs_ptr_.push_back(output); } | |||||
| void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& inputs) { inputs_ptr_ = inputs; } | |||||
| void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>>& outputs) { outputs_ptr_ = outputs; } | |||||
| void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } | |||||
| void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } | |||||
| void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } | |||||
| void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; } | |||||
| void set_outputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &outputs) { outputs_ptr_ = outputs; } | |||||
| bool is_ref() const { return !ref_infos_.empty(); } | bool is_ref() const { return !ref_infos_.empty(); } | ||||
| bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } | bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } | ||||
| void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } | ||||
| @@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) { | |||||
| return "unknow"; | return "unknow"; | ||||
| } | } | ||||
| } | } | ||||
| bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { | |||||
| bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { | |||||
| bool ret = false; | bool ret = false; | ||||
| try { | try { | ||||
| auto op_json = nlohmann::json::parse(json_string); | auto op_json = nlohmann::json::parse(json_string); | ||||
| @@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) | |||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; | MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; | ||||
| } | } | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); | MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info) { | |||||
| void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) { | |||||
| op_info->set_async_flag(obj.at(kAsyncFlag)); | op_info->set_async_flag(obj.at(kAsyncFlag)); | ||||
| op_info->set_binfile_name(obj.at(kBinfileName)); | op_info->set_binfile_name(obj.at(kBinfileName)); | ||||
| op_info->set_compute_cost(obj.at(kComputeCost)); | op_info->set_compute_cost(obj.at(kComputeCost)); | ||||
| @@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p | |||||
| } | } | ||||
| } | } | ||||
| bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, | |||||
| const std::string& impl_path) { | |||||
| bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, | |||||
| const std::string &impl_path) { | |||||
| std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); | std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| op_info->set_op_name(obj.at(kOpName)); | op_info->set_op_name(obj.at(kOpName)); | ||||
| @@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| DecodeTBESpecificInfo(obj, op_info); | DecodeTBESpecificInfo(obj, op_info); | ||||
| } | } | ||||
| auto attrs = obj.at(kAttr); | auto attrs = obj.at(kAttr); | ||||
| for (const auto& attr : attrs) { | |||||
| for (const auto &attr : attrs) { | |||||
| if (!DecodeAttr(attr, imply_type, op_info)) { | if (!DecodeAttr(attr, imply_type, op_info)) { | ||||
| MS_LOG(DEBUG) << "DecodeAttr Failed"; | MS_LOG(DEBUG) << "DecodeAttr Failed"; | ||||
| return false; | return false; | ||||
| @@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| dtype_format = obj.at(kDtypeFormat); | dtype_format = obj.at(kDtypeFormat); | ||||
| } | } | ||||
| auto inputs = obj.at(kIputs); | auto inputs = obj.at(kIputs); | ||||
| for (const auto& input : inputs) { | |||||
| for (const auto &input : inputs) { | |||||
| if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { | if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { | ||||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| auto outputs = obj.at(kOutputs); | auto outputs = obj.at(kOutputs); | ||||
| for (const auto& output : outputs) { | |||||
| for (const auto &output : outputs) { | |||||
| if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { | if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { | ||||
| MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | MS_LOG(DEBUG) << "DecodeInputOutput Failed"; | ||||
| return false; | return false; | ||||
| @@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||||
| const std::shared_ptr<OpInfo>& op_info) { | |||||
| bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, | |||||
| const std::shared_ptr<OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| bool ret = true; | bool ret = true; | ||||
| try { | try { | ||||
| @@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||||
| op_attr->set_default_value(obj.at(kDefaultValue)); | op_attr->set_default_value(obj.at(kDefaultValue)); | ||||
| } | } | ||||
| op_info->add_attrs_ptr(op_attr); | op_info->add_attrs_ptr(op_attr); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); | MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); | ||||
| ret = false; | ret = false; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||||
| bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | |||||
| size_t index) { | size_t index) { | ||||
| bool ret = true; | bool ret = true; | ||||
| try { | try { | ||||
| std::vector<std::string> dtype; | std::vector<std::string> dtype; | ||||
| std::vector<std::string> format; | std::vector<std::string> format; | ||||
| for (const auto& it : dtype_format) { | |||||
| for (const auto &it : dtype_format) { | |||||
| dtype.emplace_back(it[index][0]); | dtype.emplace_back(it[index][0]); | ||||
| format.emplace_back(it[index][1]); | format.emplace_back(it[index][1]); | ||||
| } | } | ||||
| op_io->set_dtypes(dtype); | op_io->set_dtypes(dtype); | ||||
| op_io->set_formats(format); | op_io->set_formats(format); | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); | MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); | ||||
| ret = false; | ret = false; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | |||||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format) { | |||||
| bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | |||||
| const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format) { | |||||
| bool ret = true; | bool ret = true; | ||||
| try { | try { | ||||
| std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); | std::shared_ptr<OpIOInfo> op_io = std::make_shared<OpIOInfo>(); | ||||
| @@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply | |||||
| } else if (io_type == kOutput) { | } else if (io_type == kOutput) { | ||||
| op_info->add_outputs_ptr(op_io); | op_info->add_outputs_ptr(op_io); | ||||
| } | } | ||||
| } catch (const std::exception& e) { | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); | MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); | ||||
| ret = false; | ret = false; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { | |||||
| std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { | |||||
| auto context = MsContext::GetInstance(); | auto context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| bool is_gpu = (context->device_target() == kGPUDevice); | bool is_gpu = (context->device_target() == kGPUDevice); | ||||
| @@ -260,7 +260,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||||
| << ", current op num:" << op_info_.size(); | << ", current op num:" << op_info_.size(); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| for (const auto& op_info : op_info_) { | |||||
| for (const auto &op_info : op_info_) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { | if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { | ||||
| return op_info; | return op_info; | ||||
| @@ -271,14 +271,14 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string& op_name, OpImplyType im | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) { | |||||
| bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| const auto& output_infos = op_info->outputs_ptr(); | |||||
| const auto& input_infos = op_info->inputs_ptr(); | |||||
| const auto &output_infos = op_info->outputs_ptr(); | |||||
| const auto &input_infos = op_info->inputs_ptr(); | |||||
| for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { | for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { | ||||
| const auto& out_name = output_infos[out_index]->name(); | |||||
| const auto &out_name = output_infos[out_index]->name(); | |||||
| for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { | for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { | ||||
| const auto& in_name = input_infos[in_index]->name(); | |||||
| const auto &in_name = input_infos[in_index]->name(); | |||||
| if (out_name == in_name) { | if (out_name == in_name) { | ||||
| if (op_info->has_ref_index(out_index)) { | if (op_info->has_ref_index(out_index)) { | ||||
| MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; | MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; | ||||
| @@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo>& op_info) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo>& op_info) { | |||||
| bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| for (const auto& exist_op_info : op_info_) { | |||||
| for (const auto &exist_op_info : op_info_) { | |||||
| MS_EXCEPTION_IF_NULL(exist_op_info); | MS_EXCEPTION_IF_NULL(exist_op_info); | ||||
| if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && | if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && | ||||
| exist_op_info->impl_path() != op_info->impl_path()) { | exist_op_info->impl_path() != op_info->impl_path()) { | ||||
| @@ -28,23 +28,23 @@ class OpLib { | |||||
| public: | public: | ||||
| OpLib() = default; | OpLib() = default; | ||||
| virtual ~OpLib() = default; | virtual ~OpLib() = default; | ||||
| bool RegOp(const std::string& json_string, const std::string& impl_path); | |||||
| static std::shared_ptr<OpInfo> FindOp(const std::string& op_name, OpImplyType imply_type); | |||||
| bool RegOp(const std::string &json_string, const std::string &impl_path); | |||||
| static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type); | |||||
| protected: | protected: | ||||
| static std::vector<std::shared_ptr<OpInfo>> op_info_; | static std::vector<std::shared_ptr<OpInfo>> op_info_; | ||||
| private: | private: | ||||
| static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); | |||||
| static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, | |||||
| const std::shared_ptr<OpInfo>& op_info); | |||||
| static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr<OpIOInfo>& op_io, | |||||
| static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); | |||||
| static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, | |||||
| const std::shared_ptr<OpInfo> &op_info); | |||||
| static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io, | |||||
| size_t index); | size_t index); | ||||
| static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr<OpInfo>& op_info); | |||||
| static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, | |||||
| const std::shared_ptr<OpInfo>& op_info, const nlohmann::json& dtype_format); | |||||
| static bool GetRefInfo(const std::shared_ptr<OpInfo>& op_info); | |||||
| static bool CheckRepetition(const std::shared_ptr<OpInfo>& op_info); | |||||
| static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info); | |||||
| static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, | |||||
| const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format); | |||||
| static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info); | |||||
| static bool CheckRepetition(const std::shared_ptr<OpInfo> &op_info); | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // cppcheck-suppress unusedFunction | // cppcheck-suppress unusedFunction | ||||
| std::string set_version(const std::string& version) { return version; } | |||||
| std::string set_version(const std::string &version) { return version; } | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,11 +42,11 @@ struct OpMergedInfo { | |||||
| }; | }; | ||||
| using GenAttrFuncType = | using GenAttrFuncType = | ||||
| std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto*, const PrimitivePtr&)>; | |||||
| std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>; | |||||
| template <typename T, size_t rep_cnt = 0> | template <typename T, size_t rep_cnt = 0> | ||||
| void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, | |||||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||||
| void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, | |||||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||||
| auto casted_value = dyn_cast<T>(value); | auto casted_value = dyn_cast<T>(value); | ||||
| if (casted_value == nullptr) { | if (casted_value == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; | MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; | ||||
| @@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy | |||||
| } | } | ||||
| template <size_t beg_idx = 0> | template <size_t beg_idx = 0> | ||||
| void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, | |||||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||||
| void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, | |||||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||||
| auto tuple_ptr = dyn_cast<ValueTuple>(value); | auto tuple_ptr = dyn_cast<ValueTuple>(value); | ||||
| if (tuple_ptr == nullptr) { | if (tuple_ptr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; | MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; | ||||
| @@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib | |||||
| attr_proto->set_type(attr_type); | attr_proto->set_type(attr_type); | ||||
| } | } | ||||
| void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, | |||||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||||
| void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, | |||||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | ||||
| auto attr_value = GetValue<std::string>(value); | auto attr_value = GetValue<std::string>(value); | ||||
| if (attr_value == "VALID") { | if (attr_value == "VALID") { | ||||
| @@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType | |||||
| class OpAttrInfo { | class OpAttrInfo { | ||||
| public: | public: | ||||
| OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name, | |||||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) | |||||
| OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, | |||||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) | |||||
| : attr_name_(attr_name), | : attr_name_(attr_name), | ||||
| onnx_attr_name_(onnx_attr_name), | onnx_attr_name_(onnx_attr_name), | ||||
| onnx_attr_type_(onnx_attr_type), | onnx_attr_type_(onnx_attr_type), | ||||
| fn_gen_attr_(fn_gen_attr) {} | fn_gen_attr_(fn_gen_attr) {} | ||||
| ~OpAttrInfo() {} | ~OpAttrInfo() {} | ||||
| const std::string& attr_name() const { return attr_name_; } | |||||
| const std::string& onnx_attr_name() const { return onnx_attr_name_; } | |||||
| const std::string &attr_name() const { return attr_name_; } | |||||
| const std::string &onnx_attr_name() const { return onnx_attr_name_; } | |||||
| onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } | onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } | ||||
| GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } | GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } | ||||
| @@ -134,27 +134,27 @@ class OpAttrInfo { | |||||
| class OpNameInfo { | class OpNameInfo { | ||||
| public: | public: | ||||
| OpNameInfo& set_op_type(const std::string& op_type) { | |||||
| OpNameInfo &set_op_type(const std::string &op_type) { | |||||
| op_type_ = op_type; | op_type_ = op_type; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| const std::string& op_type() const { return op_type_; } | |||||
| const std::string &op_type() const { return op_type_; } | |||||
| OpNameInfo& set_onnx_type(const std::string& onnx_type) { | |||||
| OpNameInfo &set_onnx_type(const std::string &onnx_type) { | |||||
| onnx_type_ = onnx_type; | onnx_type_ = onnx_type; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| const std::string& onnx_type() const { return onnx_type_; } | |||||
| const std::string &onnx_type() const { return onnx_type_; } | |||||
| OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, | |||||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { | |||||
| OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, | |||||
| onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { | |||||
| op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); | op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| const std::vector<OpAttrInfo>& op_attrs() const { return op_attrs_; } | |||||
| const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; } | |||||
| private: | private: | ||||
| std::string op_type_; // operator type of MindSpore | std::string op_type_; // operator type of MindSpore | ||||
| @@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE( | |||||
| .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>) | .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int32Imm>) | ||||
| .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) | .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) | ||||
| .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, | .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, | ||||
| [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, | |||||
| const PrimitivePtr& prim) { | |||||
| [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, | |||||
| const PrimitivePtr &prim) { | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); | ||||
| auto attr_value = GetValue<std::string>(value); | auto attr_value = GetValue<std::string>(value); | ||||
| if (attr_value == "valid") { | if (attr_value == "valid") { | ||||
| @@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, | |||||
| SetAttrValueToProto<Int32Imm>) | SetAttrValueToProto<Int32Imm>) | ||||
| .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, | .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, | ||||
| [](ValuePtr, onnx::AttributeProto_AttributeType, | [](ValuePtr, onnx::AttributeProto_AttributeType, | ||||
| onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { | |||||
| onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | ||||
| attr_proto->set_i(0); | attr_proto->set_i(0); | ||||
| })) | })) | ||||
| @@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( | |||||
| #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name | #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name | ||||
| void RegisterOpConverters(const std::function<void(OpNameInfo&&)>& fn) { | |||||
| void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) { | |||||
| fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); | fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); | ||||
| fn(OP_CONVERT_FUNCTION_NAME(Mul)()); | fn(OP_CONVERT_FUNCTION_NAME(Mul)()); | ||||
| @@ -265,16 +265,16 @@ class OpConvertRegistry { | |||||
| public: | public: | ||||
| ~OpConvertRegistry() { Clear(); } | ~OpConvertRegistry() { Clear(); } | ||||
| static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } | |||||
| static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } | |||||
| static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } | static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } | ||||
| static OpConvertRegistry& GetSingleton() { | |||||
| static OpConvertRegistry &GetSingleton() { | |||||
| static OpConvertRegistry registry = OpConvertRegistry(); | static OpConvertRegistry registry = OpConvertRegistry(); | ||||
| return registry; | return registry; | ||||
| } | } | ||||
| static const std::unordered_map<std::string, OpNameInfo>& GetOpConvertMap() { return GetSingleton().op_map_; } | |||||
| static const std::unordered_map<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; } | |||||
| void Clear() noexcept { op_map_.clear(); } | void Clear() noexcept { op_map_.clear(); } | ||||
| @@ -289,59 +289,59 @@ class OnnxExporter { | |||||
| OnnxExporter() {} | OnnxExporter() {} | ||||
| ~OnnxExporter() {} | ~OnnxExporter() {} | ||||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||||
| private: | private: | ||||
| void InitModelInfo(); | void InitModelInfo(); | ||||
| void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); | |||||
| void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); | |||||
| void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); | |||||
| void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); | |||||
| size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, | |||||
| onnx::GraphProto* graph_proto); | |||||
| size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs, | |||||
| onnx::GraphProto *graph_proto); | |||||
| static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | ||||
| void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); | |||||
| void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); | |||||
| void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr); | |||||
| void ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||||
| void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||||
| void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* graph_proto); | |||||
| void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* graph_proto); | |||||
| std::string GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* const graph_proto); | |||||
| void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); | |||||
| void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); | |||||
| void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); | |||||
| void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); | |||||
| void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr); | |||||
| void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||||
| void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||||
| void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); | |||||
| void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *graph_proto); | |||||
| std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *const graph_proto); | |||||
| void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); | |||||
| void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); | |||||
| size_t AllocateNodeIndex() { return ++onnx_node_index_; } | size_t AllocateNodeIndex() { return ++onnx_node_index_; } | ||||
| void ResetNodeIndex() { onnx_node_index_ = 0; } | void ResetNodeIndex() { onnx_node_index_ = 0; } | ||||
| static int GetInt32Value(const AnfNodePtr& node) { | |||||
| static int GetInt32Value(const AnfNodePtr &node) { | |||||
| auto value_node_ptr = dyn_cast<ValueNode>(node); | auto value_node_ptr = dyn_cast<ValueNode>(node); | ||||
| MS_EXCEPTION_IF_NULL(value_node_ptr); | MS_EXCEPTION_IF_NULL(value_node_ptr); | ||||
| return GetValue<int>(value_node_ptr->value()); | return GetValue<int>(value_node_ptr->value()); | ||||
| @@ -352,7 +352,7 @@ class OnnxExporter { | |||||
| size_t onnx_node_index_ = 0; | size_t onnx_node_index_ = 0; | ||||
| }; | }; | ||||
| std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||||
| std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| @@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||||
| OpConvertRegistry::GetSingleton().Clear(); | OpConvertRegistry::GetSingleton().Clear(); | ||||
| OpConvertRegistry::RegisterAllOpConverters(); | OpConvertRegistry::RegisterAllOpConverters(); | ||||
| InitModelInfo(); | InitModelInfo(); | ||||
| onnx::GraphProto* graph_proto = model_.mutable_graph(); | |||||
| onnx::GraphProto *graph_proto = model_.mutable_graph(); | |||||
| ExportFuncGraph(func_graph, graph_proto); | ExportFuncGraph(func_graph, graph_proto); | ||||
| return model_.SerializeAsString(); | return model_.SerializeAsString(); | ||||
| } | } | ||||
| @@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() { | |||||
| model_.set_ir_version(onnx::IR_VERSION_2019_1_22); | model_.set_ir_version(onnx::IR_VERSION_2019_1_22); | ||||
| model_.set_producer_name("MindSpore"); | model_.set_producer_name("MindSpore"); | ||||
| model_.set_producer_version("1.0"); | model_.set_producer_version("1.0"); | ||||
| onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); | |||||
| onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); | |||||
| opset_proto->set_version(9); | opset_proto->set_version(9); | ||||
| } | } | ||||
| void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||||
| std::map<AnfNodePtr, size_t> node_map; | std::map<AnfNodePtr, size_t> node_map; | ||||
| onnx_node_index_ = func_graph->parameters().size(); | onnx_node_index_ = func_graph->parameters().size(); | ||||
| @@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr | |||||
| ExportNodes(func_graph, &node_map, graph_proto); | ExportNodes(func_graph, &node_map, graph_proto); | ||||
| } | } | ||||
| void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { | |||||
| for (auto& param : func_graph->parameters()) { | |||||
| void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||||
| for (auto ¶m : func_graph->parameters()) { | |||||
| const ParameterPtr param_ptr = dyn_cast<Parameter>(param); | const ParameterPtr param_ptr = dyn_cast<Parameter>(param); | ||||
| if (param_ptr == nullptr) { | if (param_ptr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; | MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; | ||||
| } | } | ||||
| onnx::ValueInfoProto* input_proto = graph_proto->add_input(); | |||||
| onnx::ValueInfoProto *input_proto = graph_proto->add_input(); | |||||
| input_proto->set_name(param_ptr->ToString()); | input_proto->set_name(param_ptr->ToString()); | ||||
| SetValueInfoType(param_ptr, input_proto); | SetValueInfoType(param_ptr, input_proto); | ||||
| @@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP | |||||
| continue; | continue; | ||||
| } | } | ||||
| // parameter with default value is an ONNX initializer | // parameter with default value is an ONNX initializer | ||||
| onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); | |||||
| onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); | |||||
| initializer_proto->set_name(param_ptr->ToString()); | initializer_proto->set_name(param_ptr->ToString()); | ||||
| SetTensorProtoInfo(param_ptr, initializer_proto); | SetTensorProtoInfo(param_ptr, initializer_proto); | ||||
| // set value for initializer | // set value for initializer | ||||
| @@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { | |||||
| void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { | |||||
| auto dtype = node->Type(); | auto dtype = node->Type(); | ||||
| auto shape = node->Shape(); | auto shape = node->Shape(); | ||||
| onnx::TypeProto* type_proto = value_proto->mutable_type(); | |||||
| onnx::TypeProto *type_proto = value_proto->mutable_type(); | |||||
| if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) { | if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) { | ||||
| auto tensor = dyn_cast<TensorType>(dtype); | auto tensor = dyn_cast<TensorType>(dtype); | ||||
| auto elem_type = tensor->element(); | auto elem_type = tensor->element(); | ||||
| const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||||
| const auto &dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||||
| // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 | // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 | ||||
| auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); | auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); | ||||
| type_proto->mutable_tensor_type()->set_elem_type(type); | type_proto->mutable_tensor_type()->set_elem_type(type); | ||||
| for (const auto& dim : dims) { | |||||
| for (const auto &dim : dims) { | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { | |||||
| void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { | |||||
| auto dtype = param->Type(); | auto dtype = param->Type(); | ||||
| auto shape = param->Shape(); | auto shape = param->Shape(); | ||||
| if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | ||||
| @@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro | |||||
| auto tensor = dyn_cast<TensorType>(dtype); | auto tensor = dyn_cast<TensorType>(dtype); | ||||
| auto elem_type = tensor->element(); | auto elem_type = tensor->element(); | ||||
| const auto& dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||||
| const auto &dims = dyn_cast<abstract::Shape>(shape)->shape(); | |||||
| tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); | tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); | ||||
| for (const auto& dim : dims) { | |||||
| for (const auto &dim : dims) { | |||||
| tensor_proto->add_dims(dim); | tensor_proto->add_dims(dim); | ||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector<AnfNodePtr>& nodes, | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo>* op_merged_infos_ptr) { | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo>& op_merged_infos = *op_merged_infos_ptr; | |||||
| void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) { | |||||
| std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr; | |||||
| for (auto& node : nodes) { | |||||
| for (auto &node : nodes) { | |||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto | |||||
| // if the key `input` does not exist, just create a new one | // if the key `input` does not exist, just create a new one | ||||
| op_merged_infos[cnode].referred_count += 1; | op_merged_infos[cnode].referred_count += 1; | ||||
| } | } | ||||
| for (auto& input : cnode->inputs()) { | |||||
| for (auto &input : cnode->inputs()) { | |||||
| if (!input->isa<CNode>()) { | if (!input->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto | |||||
| * | +-- Parameter | * | +-- Parameter | ||||
| * | `-- ValueNode | * | `-- ValueNode | ||||
| */ | */ | ||||
| void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *const graph_proto) { | |||||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | ||||
| std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos; | std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos; | ||||
| MatchAndMark(func_graph, nodes, &op_merged_infos); | MatchAndMark(func_graph, nodes, &op_merged_infos); | ||||
| for (const AnfNodePtr& node : nodes) { | |||||
| for (const AnfNodePtr &node : nodes) { | |||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map<AnfNodeP | |||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | ||||
| auto input_shape = node->input(2); | auto input_shape = node->input(2); | ||||
| std::string name_shape; | std::string name_shape; | ||||
| if (input_shape->isa<ValueNode>()) { | if (input_shape->isa<ValueNode>()) { | ||||
| auto const_node_idx = AllocateNodeIndex(); | auto const_node_idx = AllocateNodeIndex(); | ||||
| (*node_map_ptr)[input_shape] = const_node_idx; | (*node_map_ptr)[input_shape] = const_node_idx; | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| name_shape = std::to_string(const_node_idx); | name_shape = std::to_string(const_node_idx); | ||||
| node_proto->add_output(name_shape); | node_proto->add_output(name_shape); | ||||
| node_proto->set_op_type("Constant"); | node_proto->set_op_type("Constant"); | ||||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name("value"); | attr_proto->set_name("value"); | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | ||||
| @@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C | |||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| (*node_map_ptr)[node] = node_idx; | (*node_map_ptr)[node] = node_idx; | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->set_op_type(prim::kPrimReshape->name()); | node_proto->set_op_type(prim::kPrimReshape->name()); | ||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| node_proto->add_input(name_x); | node_proto->add_input(name_x); | ||||
| node_proto->add_input(name_shape); | node_proto->add_input(name_shape); | ||||
| } | } | ||||
| void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *const graph_proto) { | |||||
| auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | ||||
| auto input_axis = node->input(2); | auto input_axis = node->input(2); | ||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| (*node_map_ptr)[node] = node_idx; | (*node_map_ptr)[node] = node_idx; | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->set_op_type(prim::kPrimReduceMean->name()); | node_proto->set_op_type(prim::kPrimReduceMean->name()); | ||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| node_proto->add_input(input_data); | node_proto->add_input(input_data); | ||||
| if (input_axis->isa<ValueNode>()) { | if (input_axis->isa<ValueNode>()) { | ||||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name("axes"); | attr_proto->set_name("axes"); | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | ||||
| auto axis_value = dyn_cast<ValueNode>(input_axis)->value(); | auto axis_value = dyn_cast<ValueNode>(input_axis)->value(); | ||||
| @@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons | |||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | ||||
| auto input_type = node->input(2); | auto input_type = node->input(2); | ||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| (*node_map_ptr)[node] = node_idx; | (*node_map_ptr)[node] = node_idx; | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->set_op_type(prim::kPrimCast->name()); | node_proto->set_op_type(prim::kPrimCast->name()); | ||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| node_proto->add_input(input_data); | node_proto->add_input(input_data); | ||||
| if (input_type->isa<ValueNode>()) { | if (input_type->isa<ValueNode>()) { | ||||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name("to"); | attr_proto->set_name("to"); | ||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); | ||||
| auto type_value = dyn_cast<ValueNode>(input_type)->value(); | auto type_value = dyn_cast<ValueNode>(input_type)->value(); | ||||
| @@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod | |||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); | ||||
| auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); | auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); | ||||
| @@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo | |||||
| // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] | // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] | ||||
| if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { | if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { | ||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->set_op_type("Unsqueeze"); | node_proto->set_op_type("Unsqueeze"); | ||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); | ||||
| attr_proto->set_name("axes"); | attr_proto->set_name("axes"); | ||||
| attr_proto->add_ints(1); | attr_proto->add_ints(1); | ||||
| @@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo | |||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| (*node_map_ptr)[node] = node_idx; | (*node_map_ptr)[node] = node_idx; | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->set_op_type("PRelu"); | node_proto->set_op_type("PRelu"); | ||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| node_proto->add_input(input_x); | node_proto->add_input(input_x); | ||||
| node_proto->add_input(input_slope); | node_proto->add_input(input_slope); | ||||
| } | } | ||||
| void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert | // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert | ||||
| if (node->IsApply(prim::kPrimReshape)) { | if (node->IsApply(prim::kPrimReshape)) { | ||||
| return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); | return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); | ||||
| @@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n | |||||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); | (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); | ||||
| } | } | ||||
| size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| const PrimitivePtr& prim, const std::vector<AnfNodePtr>& inputs, | |||||
| onnx::GraphProto* const graph_proto) { | |||||
| size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs, | |||||
| onnx::GraphProto *const graph_proto) { | |||||
| auto op_map = OpConvertRegistry::GetOpConvertMap(); | auto op_map = OpConvertRegistry::GetOpConvertMap(); | ||||
| auto op_iter = op_map.find(prim->name()); | auto op_iter = op_map.find(prim->name()); | ||||
| if (op_iter == op_map.end()) { | if (op_iter == op_map.end()) { | ||||
| MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; | MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; | ||||
| } | } | ||||
| const OpNameInfo& op_convert_info = op_iter->second; | |||||
| const OpNameInfo &op_convert_info = op_iter->second; | |||||
| auto node_idx = AllocateNodeIndex(); | auto node_idx = AllocateNodeIndex(); | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->add_output(std::to_string(node_idx)); | node_proto->add_output(std::to_string(node_idx)); | ||||
| node_proto->set_op_type(op_convert_info.onnx_type()); | node_proto->set_op_type(op_convert_info.onnx_type()); | ||||
| // Set inputs | // Set inputs | ||||
| for (const auto& input : inputs) { | |||||
| for (const auto &input : inputs) { | |||||
| auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); | auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); | ||||
| node_proto->add_input(input_name); | node_proto->add_input(input_name); | ||||
| } | } | ||||
| // Set node attribute | // Set node attribute | ||||
| for (const OpAttrInfo& attr : op_convert_info.op_attrs()) { | |||||
| const std::string& attr_name = attr.attr_name(); | |||||
| for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { | |||||
| const std::string &attr_name = attr.attr_name(); | |||||
| ValuePtr attr_value = nullptr; | ValuePtr attr_value = nullptr; | ||||
| if (!attr_name.empty()) { | if (!attr_name.empty()) { | ||||
| attr_value = prim->GetAttr(attr_name); | attr_value = prim->GetAttr(attr_name); | ||||
| @@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma | |||||
| MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; | MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; | ||||
| } | } | ||||
| } | } | ||||
| onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); | |||||
| onnx_attr_proto->set_name(attr.onnx_attr_name()); | onnx_attr_proto->set_name(attr.onnx_attr_name()); | ||||
| attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); | attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); | ||||
| } | } | ||||
| return node_idx; | return node_idx; | ||||
| } | } | ||||
| void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| auto conv_node = dyn_cast<CNode>(node->input(1)); | auto conv_node = dyn_cast<CNode>(node->input(1)); | ||||
| auto input_x = conv_node->input(1); // conv input x | auto input_x = conv_node->input(1); // conv input x | ||||
| auto input_w = conv_node->input(2); // conv weight(filter) | auto input_w = conv_node->input(2); // conv weight(filter) | ||||
| @@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt | |||||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); | (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); | ||||
| } | } | ||||
| void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| auto matmul_node = dyn_cast<CNode>(node->input(1)); | auto matmul_node = dyn_cast<CNode>(node->input(1)); | ||||
| auto input_x = matmul_node->input(1); // matmul input x | auto input_x = matmul_node->input(1); // matmul input x | ||||
| auto input_y = matmul_node->input(2); // matmul input y | auto input_y = matmul_node->input(2); // matmul input y | ||||
| @@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt | |||||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); | (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); | ||||
| } | } | ||||
| void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *const graph_proto) { | |||||
| auto batch_norm_node = dyn_cast<CNode>(node->input(1)); | auto batch_norm_node = dyn_cast<CNode>(node->input(1)); | ||||
| PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value()); | PrimitivePtr prim_batch_norm = dyn_cast<Primitive>((dyn_cast<ValueNode>(batch_norm_node->input(0)))->value()); | ||||
| @@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN | |||||
| (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); | (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); | ||||
| } | } | ||||
| void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, | |||||
| std::map<AnfNodePtr, size_t>* node_map_ptr, onnx::GraphProto* const graph_proto) { | |||||
| void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, | |||||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||||
| if (node->inputs().size() != 2) { | if (node->inputs().size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | ||||
| } | } | ||||
| AnfNodePtr arg = node->input(1); | AnfNodePtr arg = node->input(1); | ||||
| std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); | std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); | ||||
| onnx::ValueInfoProto* output_proto = graph_proto->add_output(); | |||||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| output_proto->set_name(name); | output_proto->set_name(name); | ||||
| SetValueInfoType(arg, output_proto, false); | SetValueInfoType(arg, output_proto, false); | ||||
| } | } | ||||
| std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfNodePtr, size_t>* node_map_ptr, | |||||
| onnx::GraphProto* const graph_proto) { | |||||
| std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||||
| onnx::GraphProto *const graph_proto) { | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto iter = node_map_ptr->find(node); | auto iter = node_map_ptr->find(node); | ||||
| if (iter == node_map_ptr->end()) { | if (iter == node_map_ptr->end()) { | ||||
| @@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN | |||||
| (*node_map_ptr)[node] = node_idx; | (*node_map_ptr)[node] = node_idx; | ||||
| std::string node_name = std::to_string(node_idx); | std::string node_name = std::to_string(node_idx); | ||||
| onnx::NodeProto* node_proto = graph_proto->add_node(); | |||||
| onnx::NodeProto *node_proto = graph_proto->add_node(); | |||||
| node_proto->add_output(node_name); | node_proto->add_output(node_name); | ||||
| SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto); | SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto); | ||||
| @@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map<AnfN | |||||
| MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); | MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); | ||||
| } | } | ||||
| void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { | |||||
| void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { | |||||
| auto tuple_ptr = dyn_cast<ValueTuple>(value); | auto tuple_ptr = dyn_cast<ValueTuple>(value); | ||||
| MS_EXCEPTION_IF_NULL(tuple_ptr); | MS_EXCEPTION_IF_NULL(tuple_ptr); | ||||
| if (tuple_ptr->size() == 0) { | if (tuple_ptr->size() == 0) { | ||||
| @@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto | |||||
| } | } | ||||
| } | } | ||||
| void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { | |||||
| void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { | |||||
| node_proto->set_op_type("Constant"); | node_proto->set_op_type("Constant"); | ||||
| onnx::AttributeProto* attr_proto = node_proto->add_attribute(); | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_name("value"); | attr_proto->set_name("value"); | ||||
| MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; | MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; | ||||
| } | } | ||||
| std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { | |||||
| OnnxExporter exporter; | OnnxExporter exporter; | ||||
| return exporter.GetOnnxProtoString(func_graph); | return exporter.GetOnnxProtoString(func_graph); | ||||
| } | } | ||||
| @@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown }; | |||||
| // Whether has a T type data in AnyPtrList. | // Whether has a T type data in AnyPtrList. | ||||
| template <class T> | template <class T> | ||||
| bool HasType(const AnyPtrList& list) { | |||||
| bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is<T>(); }); | |||||
| bool HasType(const AnyPtrList &list) { | |||||
| bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); }); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| DataType InferType(const AnyPtrList& list) { | |||||
| DataType InferType(const AnyPtrList &list) { | |||||
| if (HasType<double>(list)) { | if (HasType<double>(list)) { | ||||
| return DataType::kDouble; | return DataType::kDouble; | ||||
| } else if (HasType<float>(list)) { | } else if (HasType<float>(list)) { | ||||
| @@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) { | |||||
| } | } | ||||
| #define SCALAR_OP(op_t) \ | #define SCALAR_OP(op_t) \ | ||||
| ValuePtr Scalar##op_t(const ValuePtrList& list) { \ | |||||
| ValuePtr Scalar##op_t(const ValuePtrList &list) { \ | |||||
| do { \ | do { \ | ||||
| if (list.size() < 2) { \ | if (list.size() < 2) { \ | ||||
| MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | ||||
| @@ -223,7 +223,7 @@ SCALAR_OP(Pow) | |||||
| SCALAR_OP(Floordiv) | SCALAR_OP(Floordiv) | ||||
| #define LOGIC_OP(op_t) \ | #define LOGIC_OP(op_t) \ | ||||
| ValuePtr Scalar##op_t(const ValuePtrList& list) { \ | |||||
| ValuePtr Scalar##op_t(const ValuePtrList &list) { \ | |||||
| if (list.size() < 2) { \ | if (list.size() < 2) { \ | ||||
| MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ | ||||
| } \ | } \ | ||||
| @@ -274,7 +274,7 @@ LOGIC_OP(Ne) | |||||
| LOGIC_OP(Le) | LOGIC_OP(Le) | ||||
| LOGIC_OP(Ge) | LOGIC_OP(Ge) | ||||
| ValuePtr ScalarUAdd(const ValuePtrList& list) { | |||||
| ValuePtr ScalarUAdd(const ValuePtrList &list) { | |||||
| if (list.size() != 1) { | if (list.size() != 1) { | ||||
| MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); | MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); | ||||
| } | } | ||||
| @@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) { | |||||
| return x; | return x; | ||||
| } | } | ||||
| ValuePtr ScalarUSub(const ValuePtrList& list) { | |||||
| ValuePtr ScalarUSub(const ValuePtrList &list) { | |||||
| if (list.size() != 1) { | if (list.size() != 1) { | ||||
| MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); | MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); | ||||
| } | } | ||||
| @@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) { | |||||
| MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; | MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; | ||||
| } | } | ||||
| ValuePtr ScalarLog(const ValuePtrList& list) { | |||||
| ValuePtr ScalarLog(const ValuePtrList &list) { | |||||
| if (list.empty()) { | if (list.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; | MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; | ||||
| } | } | ||||
| @@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) { | |||||
| MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); | MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); | ||||
| } | } | ||||
| ValuePtr BoolNot(const ValuePtrList& list) { | |||||
| ValuePtr BoolNot(const ValuePtrList &list) { | |||||
| if (list.empty()) { | if (list.empty()) { | ||||
| MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; | MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; | ||||
| } | } | ||||
| @@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) { | |||||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); | MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); | ||||
| } | } | ||||
| ValuePtr BoolAnd(const ValuePtrList& list) { | |||||
| ValuePtr BoolAnd(const ValuePtrList &list) { | |||||
| if (list.size() < 2) { | if (list.size() < 2) { | ||||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; | MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; | ||||
| } | } | ||||
| @@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) { | |||||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; | MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; | ||||
| } | } | ||||
| ValuePtr BoolOr(const ValuePtrList& list) { | |||||
| ValuePtr BoolOr(const ValuePtrList &list) { | |||||
| if (list.size() < 2) { | if (list.size() < 2) { | ||||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; | MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; | ||||
| } | } | ||||
| @@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) { | |||||
| MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; | MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; | ||||
| } | } | ||||
| ValuePtr BoolEq(const ValuePtrList& list) { | |||||
| ValuePtr BoolEq(const ValuePtrList &list) { | |||||
| if (list.size() < 2) { | if (list.size() < 2) { | ||||
| MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; | MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; | ||||
| } | } | ||||
| @@ -29,29 +29,29 @@ namespace prim { | |||||
| using Any = mindspore::Any; | using Any = mindspore::Any; | ||||
| using AnyPtrList = std::vector<std::shared_ptr<Any>>; | using AnyPtrList = std::vector<std::shared_ptr<Any>>; | ||||
| using ValuePtrList = std::vector<ValuePtr>; | using ValuePtrList = std::vector<ValuePtr>; | ||||
| using OpsFunction = std::function<Any(const AnyPtrList&)>; | |||||
| using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr>&)>; | |||||
| using OpsFunction = std::function<Any(const AnyPtrList &)>; | |||||
| using AnfNodeOpsFunction = std::function<AnfNodePtr(const std::vector<AnfNodePtr> &)>; | |||||
| ValuePtr ScalarAdd(const ValuePtrList& list); | |||||
| ValuePtr ScalarSub(const ValuePtrList& list); | |||||
| ValuePtr ScalarMul(const ValuePtrList& list); | |||||
| ValuePtr ScalarDiv(const ValuePtrList& list); | |||||
| ValuePtr ScalarMod(const ValuePtrList& list); | |||||
| ValuePtr ScalarPow(const ValuePtrList& list); | |||||
| ValuePtr ScalarFloordiv(const ValuePtrList& list); | |||||
| ValuePtr ScalarUAdd(const ValuePtrList& list); | |||||
| ValuePtr ScalarUSub(const ValuePtrList& list); | |||||
| ValuePtr ScalarLog(const ValuePtrList& list); | |||||
| ValuePtr ScalarEq(const ValuePtrList& list); | |||||
| ValuePtr ScalarLt(const ValuePtrList& list); | |||||
| ValuePtr ScalarGt(const ValuePtrList& list); | |||||
| ValuePtr ScalarNe(const ValuePtrList& list); | |||||
| ValuePtr ScalarLe(const ValuePtrList& list); | |||||
| ValuePtr ScalarGe(const ValuePtrList& list); | |||||
| ValuePtr BoolNot(const ValuePtrList& list); | |||||
| ValuePtr BoolAnd(const ValuePtrList& list); | |||||
| ValuePtr BoolOr(const ValuePtrList& list); | |||||
| ValuePtr BoolEq(const ValuePtrList& list); | |||||
| ValuePtr ScalarAdd(const ValuePtrList &list); | |||||
| ValuePtr ScalarSub(const ValuePtrList &list); | |||||
| ValuePtr ScalarMul(const ValuePtrList &list); | |||||
| ValuePtr ScalarDiv(const ValuePtrList &list); | |||||
| ValuePtr ScalarMod(const ValuePtrList &list); | |||||
| ValuePtr ScalarPow(const ValuePtrList &list); | |||||
| ValuePtr ScalarFloordiv(const ValuePtrList &list); | |||||
| ValuePtr ScalarUAdd(const ValuePtrList &list); | |||||
| ValuePtr ScalarUSub(const ValuePtrList &list); | |||||
| ValuePtr ScalarLog(const ValuePtrList &list); | |||||
| ValuePtr ScalarEq(const ValuePtrList &list); | |||||
| ValuePtr ScalarLt(const ValuePtrList &list); | |||||
| ValuePtr ScalarGt(const ValuePtrList &list); | |||||
| ValuePtr ScalarNe(const ValuePtrList &list); | |||||
| ValuePtr ScalarLe(const ValuePtrList &list); | |||||
| ValuePtr ScalarGe(const ValuePtrList &list); | |||||
| ValuePtr BoolNot(const ValuePtrList &list); | |||||
| ValuePtr BoolAnd(const ValuePtrList &list); | |||||
| ValuePtr BoolOr(const ValuePtrList &list); | |||||
| ValuePtr BoolEq(const ValuePtrList &list); | |||||
| std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2); | std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2); | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared<Tail>("tail"); | |||||
| // Apply a function of two arguments cumulatively to the items of a sequence, | // Apply a function of two arguments cumulatively to the items of a sequence, | ||||
| // from left to right, so as to reduce the sequence to a single value.For example, | // from left to right, so as to reduce the sequence to a single value.For example, | ||||
| // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). | // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). | ||||
| AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { | |||||
| AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { | |||||
| std::shared_ptr<Any> ret; | std::shared_ptr<Any> ret; | ||||
| size_t size = list.size(); | size_t size = list.size(); | ||||
| if (size < 2) { | if (size < 2) { | ||||
| @@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector<AnfNodePtr>& list) { | |||||
| AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector<AnfNodePtr> &list) { | |||||
| size_t size = list.size(); | size_t size = list.size(); | ||||
| if (size < 2) { | if (size < 2) { | ||||
| MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; | MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; | ||||
| @@ -121,7 +121,7 @@ void HyperMap::Init() { | |||||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | ||||
| } | } | ||||
| HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf) | |||||
| HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf) | |||||
| : MetaFuncGraph("hyper_map"), | : MetaFuncGraph("hyper_map"), | ||||
| fn_leaf_(fn_leaf), | fn_leaf_(fn_leaf), | ||||
| broadcast_(false), | broadcast_(false), | ||||
| @@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf) | |||||
| Init(); | Init(); | ||||
| } | } | ||||
| HyperMap::HyperMap(const HyperMap& h) | |||||
| HyperMap::HyperMap(const HyperMap &h) | |||||
| : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { | : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { | ||||
| Init(); | Init(); | ||||
| } | } | ||||
| AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||||
| const ArgsPairList& arg_map) { | |||||
| AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_map) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| if (fn_arg != nullptr) { | if (fn_arg != nullptr) { | ||||
| @@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf | |||||
| } | } | ||||
| (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), | (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), | ||||
| [](const std::pair<AnfNodePtr, Any>& item) { return item.first; }); | |||||
| [](const std::pair<AnfNodePtr, Any> &item) { return item.first; }); | |||||
| return func_graph->NewCNode(inputs); | return func_graph->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| std::size_t size = type->elements().size(); | std::size_t size = type->elements().size(); | ||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { | |||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| auto lhs = std::static_pointer_cast<List>(item.second); | auto lhs = std::static_pointer_cast<List>(item.second); | ||||
| MS_EXCEPTION_IF_NULL(lhs); | MS_EXCEPTION_IF_NULL(lhs); | ||||
| return lhs->elements().size() != size; | return lhs->elements().size() != size; | ||||
| @@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph | |||||
| (void)std::transform( | (void)std::transform( | ||||
| arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), | arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), | ||||
| [&func_graph, i](const std::pair<AnfNodePtr, Any>& item) { | |||||
| [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) { | |||||
| return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | ||||
| }); | }); | ||||
| @@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List>& type, const FuncGraph | |||||
| return func_graph->NewCNode(inputs); | return func_graph->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| std::size_t size = type->elements().size(); | std::size_t size = type->elements().size(); | ||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr>& item) { | |||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| auto lhs = std::static_pointer_cast<Tuple>(item.second); | auto lhs = std::static_pointer_cast<Tuple>(item.second); | ||||
| MS_EXCEPTION_IF_NULL(lhs); | MS_EXCEPTION_IF_NULL(lhs); | ||||
| return lhs->elements().size() != size; | return lhs->elements().size() != size; | ||||
| @@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple>& type, const FuncGrap | |||||
| return func_graph->NewCNode(inputs); | return func_graph->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, | |||||
| const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||||
| AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||||
| MS_EXCEPTION_IF_NULL(type); | MS_EXCEPTION_IF_NULL(type); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class>& type, const FuncGrap | |||||
| return func_graph->NewCNode(inputs); | return func_graph->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { | |||||
| AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { | |||||
| bool found = false; | bool found = false; | ||||
| TypeId id = kObjectTypeEnd; | TypeId id = kObjectTypeEnd; | ||||
| std::pair<AnfNodePtr, TypePtr> pair; | std::pair<AnfNodePtr, TypePtr> pair; | ||||
| for (auto& item : arg_map) { | |||||
| for (auto &item : arg_map) { | |||||
| pair = item; | pair = item; | ||||
| id = item.second->type_id(); | id = item.second->type_id(); | ||||
| if (nonleaf_.count(id)) { | if (nonleaf_.count(id)) { | ||||
| @@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||||
| if (found) { | if (found) { | ||||
| // In a nonleaf situation, all arguments must have the same generic. | // In a nonleaf situation, all arguments must have the same generic. | ||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr>& item) { | |||||
| bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| if (item.first != pair.first) { | if (item.first != pair.first) { | ||||
| return item.second->type_id() != pair.second->type_id(); | return item.second->type_id() != pair.second->type_id(); | ||||
| } | } | ||||
| @@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||||
| oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" | oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" | ||||
| << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | ||||
| int idx = 0; | int idx = 0; | ||||
| for (auto& item : arg_map) { | |||||
| for (auto &item : arg_map) { | |||||
| oss << ++idx << ": " << item.second->ToString() << "\n"; | oss << ++idx << ": " << item.second->ToString() << "\n"; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); | MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); | ||||
| @@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a | |||||
| } | } | ||||
| } | } | ||||
| ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { | |||||
| ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { | |||||
| TypePtr type_tensor = std::make_shared<TensorType>(); | TypePtr type_tensor = std::make_shared<TensorType>(); | ||||
| bool flag = std::any_of( | bool flag = std::any_of( | ||||
| args_spec_list.begin(), args_spec_list.end(), | args_spec_list.begin(), args_spec_list.end(), | ||||
| [type_tensor](const std::pair<AnfNodePtr, TypePtr>& item) { return IsSubType(item.second, type_tensor); }); | |||||
| [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); }); | |||||
| if (flag && broadcast_) { | if (flag && broadcast_) { | ||||
| ArgsPairList ret; | ArgsPairList ret; | ||||
| for (auto& item : args_spec_list) { | |||||
| for (auto &item : args_spec_list) { | |||||
| if (!IsSubType(item.second, type_tensor)) { | if (!IsSubType(item.second, type_tensor)) { | ||||
| TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second); | TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second); | ||||
| ret.push_back( | ret.push_back( | ||||
| @@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL | |||||
| return args_spec_list; | return args_spec_list; | ||||
| } | } | ||||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { | |||||
| FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | ||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | ||||
| ptrGraph->debug_info()->set_name("hyper_map"); | ptrGraph->debug_info()->set_name("hyper_map"); | ||||
| @@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { | |||||
| return ptrGraph; | return ptrGraph; | ||||
| } | } | ||||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { | |||||
| abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||||
| if (fn_leaf_ == nullptr) { | if (fn_leaf_ == nullptr) { | ||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | MS_EXCEPTION_IF_NULL(args_spec_list[0]); | ||||
| // Assert that hypermap's function param does not contain free variables | // Assert that hypermap's function param does not contain free variables | ||||
| @@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& | |||||
| AbstractBasePtrList broadened; | AbstractBasePtrList broadened; | ||||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), | (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), | ||||
| [](const AbstractBasePtr& arg) -> AbstractBasePtr { | |||||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| return arg->Broaden(); | return arg->Broaden(); | ||||
| }); | }); | ||||
| return broadened; | return broadened; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { | |||||
| (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_") | (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_") | ||||
| .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) | .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) | ||||
| .def(py::init<>()); | .def(py::init<>()); | ||||
| })); | })); | ||||
| FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { | |||||
| FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { | |||||
| MS_EXCEPTION_IF_NULL(a_tuple); | MS_EXCEPTION_IF_NULL(a_tuple); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| @@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) { | |||||
| FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { | |||||
| MS_EXCEPTION_IF_NULL(a_list); | MS_EXCEPTION_IF_NULL(a_list); | ||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| @@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| if (args_spec_list.size() != 1) { | if (args_spec_list.size() != 1) { | ||||
| MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; | MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; | ||||
| } | } | ||||
| @@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) | |||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE( | REGISTER_PYBIND_DEFINE( | ||||
| Tail_, ([](const py::module* m) { | |||||
| (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string&>()); | |||||
| Tail_, ([](const py::module *m) { | |||||
| (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| int tuple_size = SizeToInt(args_spec_list.size()); | int tuple_size = SizeToInt(args_spec_list.size()); | ||||
| std::ostringstream ss; | std::ostringstream ss; | ||||
| @@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg | |||||
| return fg; | return fg; | ||||
| } | } | ||||
| GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) | |||||
| GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) | |||||
| : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { | : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { | ||||
| if (get_by_list) { | if (get_by_list) { | ||||
| signatures_ = | signatures_ = | ||||
| @@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_ | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, | |||||
| const std::vector<AnfNodePtr>& params_list, bool applyJ) { | |||||
| FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, | |||||
| const std::vector<AnfNodePtr> ¶ms_list, bool applyJ) { | |||||
| FuncGraphPtr ret = std::make_shared<FuncGraph>(); | FuncGraphPtr ret = std::make_shared<FuncGraph>(); | ||||
| ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); | ||||
| @@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||||
| void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||||
| ValueNodePtr opsTupleItem) { | ValueNodePtr opsTupleItem) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An | |||||
| } | } | ||||
| // Generate the graph. | // Generate the graph. | ||||
| FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| if (args_spec_list.size() < 1) { | if (args_spec_list.size() < 1) { | ||||
| MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " | MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " | ||||
| << args_spec_list.size() << "."; | << args_spec_list.size() << "."; | ||||
| @@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp | |||||
| return dfBuilder; | return dfBuilder; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | |||||
| (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>( | (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>( | ||||
| *m, "GradOperation_") | *m, "GradOperation_") | ||||
| .def(py::init<std::string&>(), py::arg("fn")) | |||||
| .def(py::init<std::string&, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), | |||||
| .def(py::init<std::string &>(), py::arg("fn")) | |||||
| .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"), | |||||
| py::arg("get_by_list"), py::arg("sens_param")); | py::arg("get_by_list"), py::arg("sens_param")); | ||||
| })); | })); | ||||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { | |||||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | |||||
| fn_cache_.clear(); | fn_cache_.clear(); | ||||
| signatures_ = std::vector<Signature>({// def multitype(*args:ref): | signatures_ = std::vector<Signature>({// def multitype(*args:ref): | ||||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | ||||
| } | } | ||||
| void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { | |||||
| void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { | |||||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; | MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; | ||||
| auto fn = fn_cache_.find(types); | auto fn = fn_cache_.find(types); | ||||
| if (fn != fn_cache_.end()) { | if (fn != fn_cache_.end()) { | ||||
| @@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) | |||||
| fn_cache_[types] = s_fn; | fn_cache_[types] = s_fn; | ||||
| } | } | ||||
| void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { | |||||
| void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { | |||||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; | MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; | ||||
| auto fn = fn_cache_.find(types); | auto fn = fn_cache_.find(types); | ||||
| if (fn != fn_cache_.end()) { | if (fn != fn_cache_.end()) { | ||||
| @@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& | |||||
| fn_cache_py_[types] = py_fn; | fn_cache_py_[types] = py_fn; | ||||
| } | } | ||||
| void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, const py::function& py_fn) { | |||||
| void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) { | |||||
| TypePtrList types; | TypePtrList types; | ||||
| for (auto& type_name : types_name) { | |||||
| for (auto &type_name : types_name) { | |||||
| auto type_ptr = StringToType(type_name); | auto type_ptr = StringToType(type_name); | ||||
| if (type_ptr == nullptr) { | if (type_ptr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; | MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; | ||||
| @@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector<std::string>& types_name, co | |||||
| Register(types, py_fn); | Register(types, py_fn); | ||||
| } | } | ||||
| void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { | |||||
| void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { | |||||
| std::vector<std::string> types_name; | std::vector<std::string> types_name; | ||||
| for (size_t it = 0; it < tuple.size(); ++it) { | for (size_t it = 0; it < tuple.size(); ++it) { | ||||
| py::object name_py = tuple[it]; | py::object name_py = tuple[it]; | ||||
| @@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& | |||||
| } | } | ||||
| Register(types_name, py_fn); | Register(types_name, py_fn); | ||||
| } | } | ||||
| static TypePtr UnwrapRef(const TypePtr& type) { | |||||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||||
| if (type->isa<RefType>()) { | if (type->isa<RefType>()) { | ||||
| return type->cast<RefTypePtr>()->subtype(); | return type->cast<RefTypePtr>()->subtype(); | ||||
| } | } | ||||
| return type; | return type; | ||||
| } | } | ||||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||||
| bool find_fn = false; | bool find_fn = false; | ||||
| py::function py_fn; | py::function py_fn; | ||||
| for (auto& item : fn_cache_py_) { | |||||
| for (auto &item : fn_cache_py_) { | |||||
| TypePtrList sign = item.first; | TypePtrList sign = item.first; | ||||
| if (sign.size() != types.size()) { | if (sign.size() != types.size()) { | ||||
| continue; | continue; | ||||
| @@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||||
| oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | ||||
| << "`, corresponding location info:\n"; | << "`, corresponding location info:\n"; | ||||
| int idx = 0; | int idx = 0; | ||||
| for (auto& item : fn_cache_py_) { | |||||
| for (auto &item : fn_cache_py_) { | |||||
| FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); | FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); | ||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; | MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; | ||||
| @@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { | |||||
| << oss.str(); | << oss.str(); | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { | |||||
| (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( | (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( | ||||
| *m, "MultitypeFuncGraph_") | *m, "MultitypeFuncGraph_") | ||||
| .def(py::init<std::string&>()) | |||||
| .def(py::init<std::string &>()) | |||||
| .def("register_fn", &MultitypeFuncGraph::PyRegister); | .def("register_fn", &MultitypeFuncGraph::PyRegister); | ||||
| })); | })); | ||||
| // Generate the ListMap func graph. | // Generate the ListMap func graph. | ||||
| FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| size_t args_num = args_spec_list.size(); | size_t args_num = args_spec_list.size(); | ||||
| // args: fn, list1, list2, ... | // args: fn, list1, list2, ... | ||||
| if (args_num < 2) { | if (args_num < 2) { | ||||
| @@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis | |||||
| return fg_ptr; | return fg_ptr; | ||||
| } | } | ||||
| void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgnext_ptr, | |||||
| const FuncGraphPtr& fg_ptr) { | |||||
| void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr, | |||||
| const FuncGraphPtr &fg_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(fg_ptr); | MS_EXCEPTION_IF_NULL(fg_ptr); | ||||
| AnfNodePtr fn = fg_ptr->add_parameter(); | AnfNodePtr fn = fg_ptr->add_parameter(); | ||||
| @@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& | |||||
| fgtrue_ptr->set_output(output_cnode); | fgtrue_ptr->set_output(output_cnode); | ||||
| } | } | ||||
| void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& fgcond_ptr, | |||||
| const FuncGraphPtr& fg_ptr) { | |||||
| void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr, | |||||
| const FuncGraphPtr &fg_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(fg_ptr); | MS_EXCEPTION_IF_NULL(fg_ptr); | ||||
| AnfNodePtr fn = fg_ptr->add_parameter(); | AnfNodePtr fn = fg_ptr->add_parameter(); | ||||
| @@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& | |||||
| fg_ptr->set_output(output_cnode); | fg_ptr->set_output(output_cnode); | ||||
| } | } | ||||
| FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // args: tuple1, tuple2 | // args: tuple1, tuple2 | ||||
| abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); | abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); | ||||
| AbstractBasePtr abs_a = args_spec_list[0]; | AbstractBasePtr abs_a = args_spec_list[0]; | ||||
| @@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { | |||||
| int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { | |||||
| MS_EXCEPTION_IF_NULL(scalar); | MS_EXCEPTION_IF_NULL(scalar); | ||||
| return GetValue<int>(scalar->BuildValue()); | return GetValue<int>(scalar->BuildValue()); | ||||
| } | } | ||||
| @@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) { | |||||
| return index; | return index; | ||||
| } | } | ||||
| int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { | |||||
| int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { | |||||
| MS_EXCEPTION_IF_NULL(member); | MS_EXCEPTION_IF_NULL(member); | ||||
| if (member->isa<AbstractScalar>()) { | if (member->isa<AbstractScalar>()) { | ||||
| @@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std | |||||
| << member->ToString(); | << member->ToString(); | ||||
| } | } | ||||
| void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, | |||||
| int* stop_index, int* step_value) { | |||||
| void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, | |||||
| int *stop_index, int *step_value) { | |||||
| MS_EXCEPTION_IF_NULL(tuple); | MS_EXCEPTION_IF_NULL(tuple); | ||||
| MS_EXCEPTION_IF_NULL(slice); | MS_EXCEPTION_IF_NULL(slice); | ||||
| MS_EXCEPTION_IF_NULL(start_index); | MS_EXCEPTION_IF_NULL(start_index); | ||||
| @@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl | |||||
| } | } | ||||
| } | } | ||||
| FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // slice a tuple | // slice a tuple | ||||
| // args: tuple, start index, end index, step | // args: tuple, start index, end index, step | ||||
| const std::string op_name("TupleSlice"); | const std::string op_name("TupleSlice"); | ||||
| @@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) { | |||||
| int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) { | |||||
| unsigned int number_dec = 0; | unsigned int number_dec = 0; | ||||
| for (size_t index = 0; index < number_bin.size(); index++) { | for (size_t index = 0; index < number_bin.size(); index++) { | ||||
| number_dec |= number_bin[index] << index; | number_dec |= number_bin[index] << index; | ||||
| @@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector<unsigned int>& number_bin) { | |||||
| return static_cast<int>(number_dec); | return static_cast<int>(number_dec); | ||||
| } | } | ||||
| void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vector<int>* end, | |||||
| std::vector<int>* strides, int length) { | |||||
| void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end, | |||||
| std::vector<int> *strides, int length) { | |||||
| MS_EXCEPTION_IF_NULL(slice); | MS_EXCEPTION_IF_NULL(slice); | ||||
| MS_EXCEPTION_IF_NULL(begin); | MS_EXCEPTION_IF_NULL(begin); | ||||
| MS_EXCEPTION_IF_NULL(end); | MS_EXCEPTION_IF_NULL(end); | ||||
| @@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector<int>* begin, std::vec | |||||
| strides->push_back(step_value); | strides->push_back(step_value); | ||||
| } | } | ||||
| int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector<int>& shape, | |||||
| std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { | |||||
| int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(slice_tuple); | MS_EXCEPTION_IF_NULL(slice_tuple); | ||||
| MS_EXCEPTION_IF_NULL(begin); | MS_EXCEPTION_IF_NULL(begin); | ||||
| MS_EXCEPTION_IF_NULL(end); | MS_EXCEPTION_IF_NULL(end); | ||||
| @@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, | |||||
| return ConvertBinaryToDecimal(shrink); | return ConvertBinaryToDecimal(shrink); | ||||
| } | } | ||||
| int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector<int>& shape, | |||||
| std::vector<int>* begin, std::vector<int>* end, std::vector<int>* strides) { | |||||
| int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(begin); | MS_EXCEPTION_IF_NULL(begin); | ||||
| MS_EXCEPTION_IF_NULL(end); | MS_EXCEPTION_IF_NULL(end); | ||||
| MS_EXCEPTION_IF_NULL(strides); | MS_EXCEPTION_IF_NULL(strides); | ||||
| @@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector<int>& shape, | |||||
| std::vector<int>* begin, std::vector<int>* end, | |||||
| std::vector<int>* strides) { | |||||
| int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, | |||||
| std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(begin); | MS_EXCEPTION_IF_NULL(begin); | ||||
| MS_EXCEPTION_IF_NULL(end); | MS_EXCEPTION_IF_NULL(end); | ||||
| MS_EXCEPTION_IF_NULL(strides); | MS_EXCEPTION_IF_NULL(strides); | ||||
| @@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // slice a tensor | // slice a tensor | ||||
| // args: tensor, slice or slice tuple | // args: tensor, slice or slice tuple | ||||
| const std::string op_name = std::string("TensorSlice"); | const std::string op_name = std::string("TensorSlice"); | ||||
| @@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | ||||
| } else { | } else { | ||||
| std::ostringstream args_info; | std::ostringstream args_info; | ||||
| for (const auto& arg : args_spec_list) { | |||||
| for (const auto &arg : args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| args_info << arg->ToString() << "\n"; | args_info << arg->ToString() << "\n"; | ||||
| } | } | ||||
| @@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec | |||||
| return ret_graph; | return ret_graph; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE( | |||||
| TupleAdd_, ([](const py::module* m) { | |||||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>()); | |||||
| })); | |||||
| REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | |||||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_") | |||||
| .def(py::init<std::string &>()); | |||||
| })); | |||||
| REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { | |||||
| (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_") | (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_") | ||||
| .def(py::init<std::string&>()); | |||||
| .def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { | |||||
| (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") | (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") | ||||
| .def(py::init<std::string&>()); | |||||
| .def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -47,20 +47,20 @@ using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; | |||||
| class MultitypeFuncGraph : public MetaFuncGraph { | class MultitypeFuncGraph : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit MultitypeFuncGraph(const std::string& name); | |||||
| explicit MultitypeFuncGraph(const std::string &name); | |||||
| ~MultitypeFuncGraph() override = default; | ~MultitypeFuncGraph() override = default; | ||||
| MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) | MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) | ||||
| using specialize_fn = FuncGraph* (*)(TypePtrList); | |||||
| using specialize_fn = FuncGraph *(*)(TypePtrList); | |||||
| // Register a method which specialize based on types vectors; | // Register a method which specialize based on types vectors; | ||||
| virtual void Register(const TypePtrList& types, specialize_fn s_fn); | |||||
| virtual void Register(const TypePtrList& types, const py::function& py_fn); | |||||
| virtual void Register(const std::vector<std::string>& types_name, const py::function& py_fn); | |||||
| virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); | |||||
| virtual void Register(const TypePtrList &types, specialize_fn s_fn); | |||||
| virtual void Register(const TypePtrList &types, const py::function &py_fn); | |||||
| virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn); | |||||
| virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); | |||||
| FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; | |||||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; | |||||
| size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } | size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } | ||||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual>& GetPyFunctions() const { | |||||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const { | |||||
| return fn_cache_py_; | return fn_cache_py_; | ||||
| } | } | ||||
| @@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | |||||
| class HyperMap : public MetaFuncGraph { | class HyperMap : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr); | |||||
| HyperMap(const HyperMap& h); | |||||
| explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr); | |||||
| HyperMap(const HyperMap &h); | |||||
| void Init(); | void Init(); | ||||
| HyperMap& operator=(const HyperMap& h) { | |||||
| HyperMap &operator=(const HyperMap &h) { | |||||
| if (this != &h) { | if (this != &h) { | ||||
| fn_leaf_ = h.fn_leaf_; | fn_leaf_ = h.fn_leaf_; | ||||
| broadcast_ = h.broadcast_; | broadcast_ = h.broadcast_; | ||||
| @@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph { | |||||
| ~HyperMap() override = default; | ~HyperMap() override = default; | ||||
| MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) | MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) | ||||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; | |||||
| FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; | |||||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; | |||||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; | |||||
| MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } | MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } | ||||
| private: | private: | ||||
| AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||||
| const ArgsPairList& arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<List>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||||
| const ArgsPairList& arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<Tuple>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||||
| const ArgsPairList& arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<Class>& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, | |||||
| const ArgsPairList& arg_map); | |||||
| AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); | |||||
| ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); | |||||
| AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_map); | |||||
| AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_map); | |||||
| AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); | |||||
| ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); | |||||
| MultitypeFuncGraphPtr fn_leaf_; | MultitypeFuncGraphPtr fn_leaf_; | ||||
| bool broadcast_; | bool broadcast_; | ||||
| @@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr<HyperMap>; | |||||
| class HyperMapPy : public HyperMap { | class HyperMapPy : public HyperMap { | ||||
| public: | public: | ||||
| explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph>& fn_leaf = nullptr) : HyperMap(fn_leaf) {} | |||||
| explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {} | |||||
| ~HyperMapPy() override = default; | ~HyperMapPy() override = default; | ||||
| MS_DECLARE_PARENT(HyperMapPy, HyperMap) | MS_DECLARE_PARENT(HyperMapPy, HyperMap) | ||||
| }; | }; | ||||
| @@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap; | |||||
| class Tail : public MetaFuncGraph { | class Tail : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit Tail(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit Tail(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~Tail() override = default; | ~Tail() override = default; | ||||
| MS_DECLARE_PARENT(Tail, MetaFuncGraph) | MS_DECLARE_PARENT(Tail, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); | |||||
| FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); | |||||
| FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); | |||||
| friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } | |||||
| friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using TailPtr = std::shared_ptr<Tail>; | using TailPtr = std::shared_ptr<Tail>; | ||||
| class MakeTupleGradient : public MetaFuncGraph { | class MakeTupleGradient : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~MakeTupleGradient() override = default; | ~MakeTupleGradient() override = default; | ||||
| MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) | MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; | using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; | ||||
| class GradOperation : public MetaFuncGraph { | class GradOperation : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, | |||||
| explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, | |||||
| bool sens_param = false); | bool sens_param = false); | ||||
| ~GradOperation() override = default; | ~GradOperation() override = default; | ||||
| MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) | MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) | ||||
| FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams, | |||||
| FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams, | |||||
| bool applyJ = false); | bool applyJ = false); | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| bool sens_param() const { return sens_param_; } | bool sens_param() const { return sens_param_; } | ||||
| bool get_all_; | bool get_all_; | ||||
| bool get_by_list_; | bool get_by_list_; | ||||
| bool sens_param_; | bool sens_param_; | ||||
| private: | private: | ||||
| void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||||
| void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, | |||||
| ValueNodePtr opsTupleItem); | ValueNodePtr opsTupleItem); | ||||
| }; | }; | ||||
| using GradOperationPtr = std::shared_ptr<GradOperation>; | using GradOperationPtr = std::shared_ptr<GradOperation>; | ||||
| class ListMap { | class ListMap { | ||||
| public: | public: | ||||
| explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } | |||||
| explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } | |||||
| ~ListMap() = default; | ~ListMap() = default; | ||||
| void MakeCond(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); | |||||
| void MakeNext(const std::vector<AnfNodePtr>& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); | |||||
| void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); | |||||
| void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); | |||||
| private: | private: | ||||
| std::string name_; | std::string name_; | ||||
| @@ -181,31 +181,31 @@ class ListMap { | |||||
| class TupleAdd : public MetaFuncGraph { | class TupleAdd : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~TupleAdd() override = default; | ~TupleAdd() override = default; | ||||
| MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) | MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using TupleAddPtr = std::shared_ptr<TupleAdd>; | using TupleAddPtr = std::shared_ptr<TupleAdd>; | ||||
| class TupleSlice : public MetaFuncGraph { | class TupleSlice : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~TupleSlice() override = default; | ~TupleSlice() override = default; | ||||
| MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) | MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using TupleSlicePtr = std::shared_ptr<TupleSlice>; | using TupleSlicePtr = std::shared_ptr<TupleSlice>; | ||||
| class TensorSlice : public MetaFuncGraph { | class TensorSlice : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~TensorSlice() override = default; | ~TensorSlice() override = default; | ||||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | using TensorSlicePtr = std::shared_ptr<TensorSlice>; | ||||
| @@ -34,7 +34,7 @@ namespace prim { | |||||
| namespace { | namespace { | ||||
| using PatternListType = std::initializer_list<BaseRef>; | using PatternListType = std::initializer_list<BaseRef>; | ||||
| const std::vector<Signature>& GetSignature(const ValuePtr& function) { | |||||
| const std::vector<Signature> &GetSignature(const ValuePtr &function) { | |||||
| static const auto empty = std::vector<Signature>(); | static const auto empty = std::vector<Signature>(); | ||||
| if (function->isa<Primitive>()) { | if (function->isa<Primitive>()) { | ||||
| return function->cast<PrimitivePtr>()->signatures(); | return function->cast<PrimitivePtr>()->signatures(); | ||||
| @@ -44,8 +44,8 @@ const std::vector<Signature>& GetSignature(const ValuePtr& function) { | |||||
| return empty; | return empty; | ||||
| } | } | ||||
| void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, | |||||
| const std::vector<Signature>& signature, bool has_var, std::vector<AnfNodePtr>* op_inputs) { | |||||
| void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, | |||||
| const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *op_inputs) { | |||||
| std::size_t sig_size = signature.size(); | std::size_t sig_size = signature.size(); | ||||
| auto positional_size = sig_size; | auto positional_size = sig_size; | ||||
| if (has_var) { | if (has_var) { | ||||
| @@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg | |||||
| } | } | ||||
| // Get the largest type of index in the same SignatureEnumDType of arguments. | // Get the largest type of index in the same SignatureEnumDType of arguments. | ||||
| std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType>& dtypes, | |||||
| const abstract::AbstractBasePtrList& args_spec_list) { | |||||
| std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<SignatureEnumDType> &dtypes, | |||||
| const abstract::AbstractBasePtrList &args_spec_list) { | |||||
| // record index for signature.dtypes of the same type | // record index for signature.dtypes of the same type | ||||
| // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} | ||||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | ||||
| @@ -89,7 +89,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur | |||||
| continue; | continue; | ||||
| } | } | ||||
| for (const auto& index : indexs) { | |||||
| for (const auto &index : indexs) { | |||||
| AbstractBasePtr arg_value = args_spec_list[index]; | AbstractBasePtr arg_value = args_spec_list[index]; | ||||
| if (arg_value->isa<abstract::AbstractRef>()) { | if (arg_value->isa<abstract::AbstractRef>()) { | ||||
| arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); | ||||
| @@ -104,7 +104,7 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur | |||||
| return dst_type; | return dst_type; | ||||
| } | } | ||||
| AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const FuncGraphPtr& graph) { | |||||
| AnfNodePtr DoCast(const AnfNodePtr ¶m, const AnfNodePtr &source_param, const FuncGraphPtr &graph) { | |||||
| // op and module import path | // op and module import path | ||||
| auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional"); | auto prim_dtype = prim::GetPythonOps("dtype", "mindspore.ops.functional"); | ||||
| MS_EXCEPTION_IF_NULL(prim_dtype); | MS_EXCEPTION_IF_NULL(prim_dtype); | ||||
| @@ -116,11 +116,11 @@ AnfNodePtr DoCast(const AnfNodePtr& param, const AnfNodePtr& source_param, const | |||||
| return NewCNode({cast_node, param, dtype_node}, graph); | return NewCNode({cast_node, param, dtype_node}, graph); | ||||
| } | } | ||||
| void DoAutoCast(const std::vector<Signature>& signature, const abstract::AbstractBasePtrList& args_spec_list, | |||||
| const FuncGraphPtr& graph, std::vector<AnfNodePtr>* op_inputs) { | |||||
| void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list, | |||||
| const FuncGraphPtr &graph, std::vector<AnfNodePtr> *op_inputs) { | |||||
| std::vector<SignatureEnumDType> dtypes; | std::vector<SignatureEnumDType> dtypes; | ||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | ||||
| [](const Signature& sig) { return sig.dtype; }); | |||||
| [](const Signature &sig) { return sig.dtype; }); | |||||
| int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | ||||
| if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { | if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { | ||||
| return; | return; | ||||
| @@ -143,10 +143,10 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac | |||||
| } | } | ||||
| } | } | ||||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||||
| const AbstractBasePtrList& args_spec_list, const std::vector<AnfNodePtr>& params_list) { | |||||
| AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||||
| const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) { | |||||
| // args: original inputs | // args: original inputs | ||||
| auto& signature = GetSignature(function); | |||||
| auto &signature = GetSignature(function); | |||||
| std::size_t sig_size = signature.size(); | std::size_t sig_size = signature.size(); | ||||
| auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); | auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); | ||||
| if (sig_size > 0) { | if (sig_size > 0) { | ||||
| @@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||||
| const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { | |||||
| AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||||
| const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { | |||||
| auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); | auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); | ||||
| return new_cnode; | return new_cnode; | ||||
| } | } | ||||
| FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | ||||
| for (size_t i = 0; i < args_spec_list.size(); ++i) { | for (size_t i = 0; i < args_spec_list.size(); ++i) { | ||||
| @@ -37,17 +37,17 @@ namespace mindspore { | |||||
| namespace prim { | namespace prim { | ||||
| class DoSignatureMetaFuncGraph : public MetaFuncGraph { | class DoSignatureMetaFuncGraph : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) | |||||
| explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) | |||||
| : MetaFuncGraph("S-" + name), function_(function) {} | : MetaFuncGraph("S-" + name), function_(function) {} | ||||
| ~DoSignatureMetaFuncGraph() override = default; | ~DoSignatureMetaFuncGraph() override = default; | ||||
| MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) | MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; | |||||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; | |||||
| const ValuePtr function() const { return function_; } | const ValuePtr function() const { return function_; } | ||||
| friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { | |||||
| friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { | |||||
| return &lhs == &rhs; | return &lhs == &rhs; | ||||
| } | } | ||||
| @@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { | |||||
| }; | }; | ||||
| using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; | using RWSignaturePtr = std::shared_ptr<DoSignatureMetaFuncGraph>; | ||||
| AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, | |||||
| const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); | |||||
| AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, | |||||
| const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); | |||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support composite operators definition | // namespace to support composite operators definition | ||||
| namespace prim { | namespace prim { | ||||
| FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { | |||||
| FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { | |||||
| abstract::CheckArgsSize("ListAppend", args_list, 2); | abstract::CheckArgsSize("ListAppend", args_list, 2); | ||||
| AbstractBasePtr arg0 = args_list[0]; | AbstractBasePtr arg0 = args_list[0]; | ||||
| @@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { | |||||
| (void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_") | (void)py::class_<ListAppend, MetaFuncGraph, std::shared_ptr<ListAppend>>(*m, "ListAppend_") | ||||
| .def(py::init<std::string&>()); | |||||
| .def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,15 +28,15 @@ namespace mindspore { | |||||
| namespace prim { | namespace prim { | ||||
| class ListAppend : public MetaFuncGraph { | class ListAppend : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~ListAppend() override = default; | ~ListAppend() override = default; | ||||
| MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) | MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; | |||||
| friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { | |||||
| FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override; | |||||
| friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) { | |||||
| os << list_append.name_; | os << list_append.name_; | ||||
| return os; | return os; | ||||
| } | } | ||||
| friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } | |||||
| friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using ListAppendPtr = std::shared_ptr<ListAppend>; | using ListAppendPtr = std::shared_ptr<ListAppend>; | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg; | |||||
| using mindspore::abstract::AbstractTuple; | using mindspore::abstract::AbstractTuple; | ||||
| using mindspore::abstract::AbstractTuplePtr; | using mindspore::abstract::AbstractTuplePtr; | ||||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // slice a tensor | // slice a tensor | ||||
| // args: tensor, slice or slice tuple | // args: tensor, slice or slice tuple | ||||
| const std::string op_name = std::string("UnpackCall"); | const std::string op_name = std::string("UnpackCall"); | ||||
| @@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||||
| AnfNodePtr para_dict = ret_graph->add_parameter(); | AnfNodePtr para_dict = ret_graph->add_parameter(); | ||||
| auto dict_elems = arg_dict->elements(); | auto dict_elems = arg_dict->elements(); | ||||
| (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), | (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), | ||||
| [ret_graph, para_dict](const AbstractAttribute& item) { | |||||
| [ret_graph, para_dict](const AbstractAttribute &item) { | |||||
| auto dict_get_item = ret_graph->NewCNode( | auto dict_get_item = ret_graph->NewCNode( | ||||
| {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); | {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); | ||||
| return ret_graph->NewCNode( | return ret_graph->NewCNode( | ||||
| @@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ | |||||
| return ret_graph; | return ret_graph; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { | |||||
| (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") | (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_") | ||||
| .def(py::init<std::string&>()); | |||||
| .def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -40,11 +40,11 @@ namespace prim { | |||||
| // and generate positional parameters and key-value pairs for function. | // and generate positional parameters and key-value pairs for function. | ||||
| class UnpackCall : public MetaFuncGraph { | class UnpackCall : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~UnpackCall() override = default; | ~UnpackCall() override = default; | ||||
| MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) | MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using UnpackCallPtr = std::shared_ptr<UnpackCall>; | using UnpackCallPtr = std::shared_ptr<UnpackCall>; | ||||
| @@ -36,7 +36,7 @@ namespace prim { | |||||
| using mindspore::abstract::AbstractBase; | using mindspore::abstract::AbstractBase; | ||||
| using mindspore::abstract::AbstractTuple; | using mindspore::abstract::AbstractTuple; | ||||
| FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { | |||||
| FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // zip operation: | // zip operation: | ||||
| // input: tuple arguments | // input: tuple arguments | ||||
| // output: tuple of items of input iterated on every input | // output: tuple of items of input iterated on every input | ||||
| @@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||||
| MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; | MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; | ||||
| } | } | ||||
| auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { | |||||
| auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { | |||||
| MS_EXCEPTION_IF_NULL(abs); | MS_EXCEPTION_IF_NULL(abs); | ||||
| return abs->isa<AbstractTuple>(); | return abs->isa<AbstractTuple>(); | ||||
| }); | }); | ||||
| @@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||||
| } | } | ||||
| auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), | auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), | ||||
| [](const AbstractBasePtr& x, const AbstractBasePtr& y) { | |||||
| [](const AbstractBasePtr &x, const AbstractBasePtr &y) { | |||||
| return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); | ||||
| }); | }); | ||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | ||||
| @@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe | |||||
| return ret_graph; | return ret_graph; | ||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { | |||||
| REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { | |||||
| (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, | (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, | ||||
| "ZipOperation_") | "ZipOperation_") | ||||
| .def(py::init<std::string&>()); | |||||
| .def(py::init<std::string &>()); | |||||
| })); | })); | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr; | |||||
| class ZipOperation : public MetaFuncGraph { | class ZipOperation : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} | |||||
| explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~ZipOperation() override = default; | ~ZipOperation() override = default; | ||||
| MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) | MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; | |||||
| friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { | |||||
| os << op.name_; | os << op.name_; | ||||
| return os; | return os; | ||||
| } | } | ||||
| friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; } | |||||
| friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | }; | ||||
| using ZipOperationPtr = std::shared_ptr<ZipOperation>; | using ZipOperationPtr = std::shared_ptr<ZipOperation>; | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary | |||||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | ||||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | ||||
| ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { | |||||
| ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { | |||||
| py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); | py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); | ||||
| ValuePtr node = nullptr; | ValuePtr node = nullptr; | ||||
| bool succ = parse::ConvertData(obj, &node); | bool succ = parse::ConvertData(obj, &node); | ||||
| @@ -26,8 +26,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support primitive operators | // namespace to support primitive operators | ||||
| namespace prim { | namespace prim { | ||||
| ValuePtr GetPythonOps(const std::string& op_name, | |||||
| const std::string& module_name = "mindspore._extends.parse.standard_method"); | |||||
| ValuePtr GetPythonOps(const std::string &op_name, | |||||
| const std::string &module_name = "mindspore._extends.parse.standard_method"); | |||||
| // Arithmetic | // Arithmetic | ||||
| extern const PrimitivePtr kPrimScalarAdd; | extern const PrimitivePtr kPrimScalarAdd; | ||||
| @@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset; | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) | |||||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | |||||
| : Primitive("S-Prim-" + name), function_(function) {} | : Primitive("S-Prim-" + name), function_(function) {} | ||||
| ~DoSignaturePrimitive() override = default; | ~DoSignaturePrimitive() override = default; | ||||
| @@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>; | |||||
| class UnpackGraphPrimitive : public Primitive { | class UnpackGraphPrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) | |||||
| explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) | |||||
| : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} | : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} | ||||
| ~UnpackGraphPrimitive() override = default; | ~UnpackGraphPrimitive() override = default; | ||||
| MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) | MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) | ||||
| @@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction() | |||||
| {"scalar_sub", kPrimTypeTwoArgs}, | {"scalar_sub", kPrimTypeTwoArgs}, | ||||
| {"scalar_floordiv", kPrimTypeTwoArgs}}) {} | {"scalar_floordiv", kPrimTypeTwoArgs}}) {} | ||||
| bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { | |||||
| bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { | |||||
| bool result = false; | bool result = false; | ||||
| if (func != nullptr) { | if (func != nullptr) { | ||||
| @@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu | |||||
| return result; | return result; | ||||
| } | } | ||||
| int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { | |||||
| int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| int prim_type = static_cast<int>(kPrimTypeUnknown); | int prim_type = static_cast<int>(kPrimTypeUnknown); | ||||
| @@ -41,21 +41,21 @@ class PrimToFunction; | |||||
| class PrimToFunction { | class PrimToFunction { | ||||
| public: | public: | ||||
| // Return a thread-safe singleton instance | // Return a thread-safe singleton instance | ||||
| static PrimToFunction& GetInstance() { | |||||
| static PrimToFunction &GetInstance() { | |||||
| static PrimToFunction instance; | static PrimToFunction instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| PrimToFunction(const PrimToFunction&) = delete; | |||||
| PrimToFunction& operator=(const PrimToFunction&) = delete; | |||||
| PrimToFunction(const PrimToFunction &) = delete; | |||||
| PrimToFunction &operator=(const PrimToFunction &) = delete; | |||||
| ~PrimToFunction() = default; | ~PrimToFunction() = default; | ||||
| // Get the args and return value for a primitive instance. | // Get the args and return value for a primitive instance. | ||||
| bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; | |||||
| bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const; | |||||
| private: | private: | ||||
| PrimToFunction(); | PrimToFunction(); | ||||
| // Get the number of primitive arguments | // Get the number of primitive arguments | ||||
| int GetPrimType(const PrimitivePtr& prim) const; | |||||
| int GetPrimType(const PrimitivePtr &prim) const; | |||||
| const std::unordered_map<std::string, int> prim_func_type_map_; | const std::unordered_map<std::string, int> prim_func_type_map_; | ||||
| }; | }; | ||||
| } // namespace prim | } // namespace prim | ||||
| @@ -24,7 +24,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ad { | namespace ad { | ||||
| Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) | |||||
| Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) | |||||
| : primal_(primal), caller_(caller), dout_(nullptr) { | : primal_(primal), caller_(caller), dout_(nullptr) { | ||||
| if (k != nullptr) { | if (k != nullptr) { | ||||
| k_ = k; | k_ = k; | ||||
| @@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP | |||||
| AnfNodePtr Adjoint::k() { return k_; } | AnfNodePtr Adjoint::k() { return k_; } | ||||
| void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } | |||||
| void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } | |||||
| void Adjoint::UpdateK(const AnfNodePtr& new_k) { | |||||
| void Adjoint::UpdateK(const AnfNodePtr &new_k) { | |||||
| MS_EXCEPTION_IF_NULL(new_k); | MS_EXCEPTION_IF_NULL(new_k); | ||||
| MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); | MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); | ||||
| // In recursive case, it needs update. | // In recursive case, it needs update. | ||||
| for (auto& user : k_user_) { | |||||
| for (auto &user : k_user_) { | |||||
| MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" | MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" | ||||
| << new_k->ToString(); | << new_k->ToString(); | ||||
| if (user.first->input(user.second) != k_) { | if (user.first->input(user.second) != k_) { | ||||
| @@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; } | |||||
| AnfNodePtr Adjoint::dout() { return dout_hole_; } | AnfNodePtr Adjoint::dout() { return dout_hole_; } | ||||
| void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { | |||||
| void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { | |||||
| dout_user_.emplace_back(std::make_pair(user, index)); | dout_user_.emplace_back(std::make_pair(user, index)); | ||||
| } | } | ||||
| void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { | |||||
| void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { | |||||
| if (dout_ != nullptr) { | if (dout_ != nullptr) { | ||||
| MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); | MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); | ||||
| auto add = prim::GetPythonOps("hyper_add"); | auto add = prim::GetPythonOps("hyper_add"); | ||||
| @@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { | |||||
| void Adjoint::CallDoutHole() { | void Adjoint::CallDoutHole() { | ||||
| if (dout_ != nullptr) { | if (dout_ != nullptr) { | ||||
| for (auto& user : dout_user_) { | |||||
| for (auto &user : dout_user_) { | |||||
| MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " | MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " | ||||
| << dout_->ToString(); | << dout_->ToString(); | ||||
| if (user.first->input(user.second) != dout_hole_) { | if (user.first->input(user.second) != dout_hole_) { | ||||
| @@ -28,15 +28,15 @@ namespace mindspore { | |||||
| namespace ad { | namespace ad { | ||||
| class Adjoint { | class Adjoint { | ||||
| public: | public: | ||||
| Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); | |||||
| Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); | |||||
| ~Adjoint() = default; | ~Adjoint() = default; | ||||
| AnfNodePtr primal(); | AnfNodePtr primal(); | ||||
| AnfNodePtr k(); | AnfNodePtr k(); | ||||
| void UpdateK(const AnfNodePtr& k); | |||||
| void RegisterKUser(const CNodePtr& user, size_t index); | |||||
| void UpdateK(const AnfNodePtr &k); | |||||
| void RegisterKUser(const CNodePtr &user, size_t index); | |||||
| AnfNodePtr dout(); | AnfNodePtr dout(); | ||||
| void AccumulateDout(const AnfNodePtr& dout_factor); | |||||
| void RegisterDoutUser(const CNodePtr& user, size_t index); | |||||
| void AccumulateDout(const AnfNodePtr &dout_factor); | |||||
| void RegisterDoutUser(const CNodePtr &user, size_t index); | |||||
| void CallDoutHole(); | void CallDoutHole(); | ||||
| private: | private: | ||||
| @@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList; | |||||
| using mindspore::abstract::AbstractScalar; | using mindspore::abstract::AbstractScalar; | ||||
| using mindspore::abstract::AbstractTuple; | using mindspore::abstract::AbstractTuple; | ||||
| static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||||
| static AbstractBasePtr Reabs(const AbstractBasePtr &t) { | |||||
| if (t == nullptr) { | if (t == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||||
| AbstractBasePtrList baselist; | AbstractBasePtrList baselist; | ||||
| auto attributes = abs_class->attributes(); | auto attributes = abs_class->attributes(); | ||||
| (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), | (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), | ||||
| [](const AbstractAttribute& item) { return item.second; }); | |||||
| [](const AbstractAttribute &item) { return item.second; }); | |||||
| res = std::make_shared<AbstractTuple>(baselist); | res = std::make_shared<AbstractTuple>(baselist); | ||||
| } else if (t->isa<AbstractDictionary>()) { | } else if (t->isa<AbstractDictionary>()) { | ||||
| auto abs_dict = dyn_cast<AbstractDictionary>(t); | auto abs_dict = dyn_cast<AbstractDictionary>(t); | ||||
| AbstractBasePtrList baselist; | AbstractBasePtrList baselist; | ||||
| auto elements = abs_dict->elements(); | auto elements = abs_dict->elements(); | ||||
| (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), | (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), | ||||
| [](const AbstractAttribute& item) { return item.second; }); | |||||
| [](const AbstractAttribute &item) { return item.second; }); | |||||
| res = std::make_shared<AbstractTuple>(baselist); | res = std::make_shared<AbstractTuple>(baselist); | ||||
| } else if (t->isa<AbstractList>()) { | } else if (t->isa<AbstractList>()) { | ||||
| auto abs_dict = dyn_cast<AbstractList>(t); | auto abs_dict = dyn_cast<AbstractList>(t); | ||||
| @@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { | |||||
| return res; | return res; | ||||
| } | } | ||||
| AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [getattr, data, attribute] | // Inputs should be [getattr, data, attribute] | ||||
| MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); | MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); | ||||
| @@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||||
| auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | ||||
| auto ct = dyn_cast<AbstractClass>(dt); | auto ct = dyn_cast<AbstractClass>(dt); | ||||
| const auto& cmap = ct->attributes(); | |||||
| const auto &cmap = ct->attributes(); | |||||
| int count = 0; | int count = 0; | ||||
| for (auto& item : cmap) { | |||||
| for (auto &item : cmap) { | |||||
| if (cons_is_str && item.first == cons_str) { | if (cons_is_str && item.first == cons_str) { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | ||||
| } | } | ||||
| AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| // Inputs should be [dict_getitem, dict, item] | // Inputs should be [dict_getitem, dict, item] | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); | MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); | ||||
| AnfNodePtr data = inputs[1]; | AnfNodePtr data = inputs[1]; | ||||
| @@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||||
| auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : ""; | ||||
| auto ct = dyn_cast<abstract::AbstractDictionary>(dt); | auto ct = dyn_cast<abstract::AbstractDictionary>(dt); | ||||
| const auto& cmap = ct->elements(); | |||||
| const auto &cmap = ct->elements(); | |||||
| int count = 0; | int count = 0; | ||||
| for (auto& item : cmap) { | |||||
| for (auto &item : cmap) { | |||||
| if (cons_is_str && item.first == cons_str) { | if (cons_is_str && item.first == cons_str) { | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); | ||||
| } | } | ||||
| AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| @@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode(inputs); | return node->func_graph()->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr ErasePartialNode(const CNodePtr& node) { | |||||
| AnfNodePtr ErasePartialNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; | // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; | ||||
| MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); | MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); | ||||
| @@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| @@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode(inputs); | return node->func_graph()->NewCNode(inputs); | ||||
| } | } | ||||
| AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [list_getitem, list, item] | // Inputs should be [list_getitem, list, item] | ||||
| if (inputs.size() < 3) { | if (inputs.size() < 3) { | ||||
| MS_LOG(EXCEPTION) << "Node's input number < 3."; | MS_LOG(EXCEPTION) << "Node's input number < 3."; | ||||
| @@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); | return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); | ||||
| } | } | ||||
| AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { | |||||
| AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | MS_EXCEPTION_IF_NULL(node->func_graph()); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [list_setitem, list, index, item] | // Inputs should be [list_setitem, list, index, item] | ||||
| if (inputs.size() < 4) { | if (inputs.size() < 4) { | ||||
| MS_LOG(EXCEPTION) << "Node's input number < 4."; | MS_LOG(EXCEPTION) << "Node's input number < 4."; | ||||
| @@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { | |||||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); | return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); | ||||
| } | } | ||||
| AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { | |||||
| AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); | MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); | ||||
| return inputs[2]; | return inputs[2]; | ||||
| } | } | ||||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { | |||||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [make_keyword_arg, key, value] | // Inputs should be [make_keyword_arg, key, value] | ||||
| MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); | MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); | ||||
| return inputs[2]; | return inputs[2]; | ||||
| } | } | ||||
| AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { | |||||
| AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| const auto& inputs = node->inputs(); | |||||
| const auto &inputs = node->inputs(); | |||||
| // Inputs should be [extract_keyword_arg, arg, key] | // Inputs should be [extract_keyword_arg, arg, key] | ||||
| MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); | MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); | ||||
| return inputs[2]; | return inputs[2]; | ||||
| } | } | ||||
| ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { | |||||
| ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { | |||||
| const int DEPTH_MAX = 5; | const int DEPTH_MAX = 5; | ||||
| if (depth > DEPTH_MAX) { | if (depth > DEPTH_MAX) { | ||||
| MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; | MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; | ||||
| } | } | ||||
| std::vector<ValuePtr> elements; | std::vector<ValuePtr> elements; | ||||
| for (const auto& it : value_list->value()) { | |||||
| for (const auto &it : value_list->value()) { | |||||
| ValuePtr value = nullptr; | ValuePtr value = nullptr; | ||||
| if (it->isa<ValueList>()) { | if (it->isa<ValueList>()) { | ||||
| value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1); | value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1); | ||||
| @@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d | |||||
| return std::make_shared<ValueTuple>(elements); | return std::make_shared<ValueTuple>(elements); | ||||
| } | } | ||||
| AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { | |||||
| AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| ValuePtr value = node->value(); | ValuePtr value = node->value(); | ||||
| auto value_list = value->cast<ValueListPtr>(); | auto value_list = value->cast<ValueListPtr>(); | ||||
| @@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { | |||||
| // Convert class to Tuple | // Convert class to Tuple | ||||
| // Convert getattr to getitem | // Convert getattr to getitem | ||||
| // Convert make_record to make_tuple | // Convert make_record to make_tuple | ||||
| void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||||
| void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| manager->AddFuncGraph(root); | manager->AddFuncGraph(root); | ||||
| // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | ||||
| AnfNodeSet all_node = manager->all_nodes(); | AnfNodeSet all_node = manager->all_nodes(); | ||||
| for (auto& node : all_node) { | |||||
| for (auto &node : all_node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| AnfNodePtr new_node = nullptr; | AnfNodePtr new_node = nullptr; | ||||
| @@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& | |||||
| } | } | ||||
| } | } | ||||
| for (auto& node : manager->all_nodes()) { | |||||
| for (auto &node : manager->all_nodes()) { | |||||
| auto ret = Reabs(node->abstract()); | auto ret = Reabs(node->abstract()); | ||||
| node->set_abstract(ret); | node->set_abstract(ret); | ||||
| } | } | ||||
| } | } | ||||
| // expand tuples in graph parameters | // expand tuples in graph parameters | ||||
| static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, | |||||
| const std::vector<AnfNodePtr>& params) { | |||||
| static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, | |||||
| const std::vector<AnfNodePtr> ¶ms) { | |||||
| MS_EXCEPTION_IF_NULL(mng); | MS_EXCEPTION_IF_NULL(mng); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| std::vector<AnfNodePtr> new_params; | std::vector<AnfNodePtr> new_params; | ||||
| for (const auto& param : params) { | |||||
| for (const auto ¶m : params) { | |||||
| MS_EXCEPTION_IF_NULL(param); | MS_EXCEPTION_IF_NULL(param); | ||||
| auto param_abs = param->abstract(); | auto param_abs = param->abstract(); | ||||
| MS_EXCEPTION_IF_NULL(param_abs); | MS_EXCEPTION_IF_NULL(param_abs); | ||||
| @@ -350,7 +350,7 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con | |||||
| std::vector<AnfNodePtr> new_param; | std::vector<AnfNodePtr> new_param; | ||||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)}; | std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)}; | ||||
| auto abs_tuple = dyn_cast<AbstractTuple>(param_abs); | auto abs_tuple = dyn_cast<AbstractTuple>(param_abs); | ||||
| for (auto& elem : abs_tuple->elements()) { | |||||
| for (auto &elem : abs_tuple->elements()) { | |||||
| auto np = std::make_shared<Parameter>(func_graph); | auto np = std::make_shared<Parameter>(func_graph); | ||||
| np->set_abstract(elem); | np->set_abstract(elem); | ||||
| new_param.emplace_back(np); | new_param.emplace_back(np); | ||||
| @@ -366,11 +366,11 @@ static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr& mng, con | |||||
| } | } | ||||
| // expand tuples in graph applies | // expand tuples in graph applies | ||||
| static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const std::vector<AnfNodePtr>& inputs) { | |||||
| static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::vector<AnfNodePtr> new_inputs; | std::vector<AnfNodePtr> new_inputs; | ||||
| for (const auto& input : inputs) { | |||||
| for (const auto &input : inputs) { | |||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| auto input_abs = input->abstract(); | auto input_abs = input->abstract(); | ||||
| @@ -391,7 +391,7 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st | |||||
| int idx = 0; | int idx = 0; | ||||
| std::vector<AnfNodePtr> new_input; | std::vector<AnfNodePtr> new_input; | ||||
| auto abs_tuple = dyn_cast<AbstractTuple>(input_abs); | auto abs_tuple = dyn_cast<AbstractTuple>(input_abs); | ||||
| for (auto& elem : abs_tuple->elements()) { | |||||
| for (auto &elem : abs_tuple->elements()) { | |||||
| auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); | auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); | ||||
| AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx)); | AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx)); | ||||
| c_node->input(2)->set_abstract(aptr); | c_node->input(2)->set_abstract(aptr); | ||||
| @@ -416,19 +416,19 @@ static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr& graph, const st | |||||
| // tuples in Graph's parameters: AbstractTuple (a, b, c) --> | // tuples in Graph's parameters: AbstractTuple (a, b, c) --> | ||||
| // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) | // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) | ||||
| // cppcheck-suppress unusedFunction | // cppcheck-suppress unusedFunction | ||||
| void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||||
| void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| manager->AddFuncGraph(root); | manager->AddFuncGraph(root); | ||||
| // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var | ||||
| AnfNodeSet all_node = manager->all_nodes(); | AnfNodeSet all_node = manager->all_nodes(); | ||||
| for (auto& node : all_node) { | |||||
| for (auto &node : all_node) { | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode == nullptr) { | if (cnode == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| const auto& inputs = cnode->inputs(); | |||||
| const auto &inputs = cnode->inputs(); | |||||
| // Bypass the first input in inputs as it's fn. | // Bypass the first input in inputs as it's fn. | ||||
| if (!IsValueNode<Primitive>(inputs[0])) { | if (!IsValueNode<Primitive>(inputs[0])) { | ||||
| @@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { | |||||
| } | } | ||||
| FuncGraphSet all_graph = manager->func_graphs(); | FuncGraphSet all_graph = manager->func_graphs(); | ||||
| for (auto& func_graph : all_graph) { | |||||
| for (auto &func_graph : all_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); | auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); | ||||
| manager->SetParameters(func_graph, expand_p); | manager->SetParameters(func_graph, expand_p); | ||||
| @@ -22,7 +22,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| // Automatically adding control depend based on effect order and side effect analysis. | // Automatically adding control depend based on effect order and side effect analysis. | ||||
| void AddControlDepend(const FuncGraphPtr& graph); | |||||
| void AddControlDepend(const FuncGraphPtr &graph); | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ | #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ | ||||
| @@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func | |||||
| nodes.push_back(func_node); | nodes.push_back(func_node); | ||||
| // {unpackcall, {GradOperation, ...}, args...} | // {unpackcall, {GradOperation, ...}, args...} | ||||
| std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), | std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), | ||||
| [](const AnfNodePtr& node) { return node; }); | |||||
| [](const AnfNodePtr &node) { return node; }); | |||||
| unpack_graph_node = func_graph->NewCNode(nodes); | unpack_graph_node = func_graph->NewCNode(nodes); | ||||
| } else { | } else { | ||||
| auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false); | auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false); | ||||
| @@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, Func | |||||
| nodes.push_back(func_node); | nodes.push_back(func_node); | ||||
| // {{GradOperation, ...}, args...} | // {{GradOperation, ...}, args...} | ||||
| std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), | std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), | ||||
| [](const AnfNodePtr& node) { return node; }); | |||||
| [](const AnfNodePtr &node) { return node; }); | |||||
| unpack_graph_node = func_graph->NewCNode(nodes); | unpack_graph_node = func_graph->NewCNode(nodes); | ||||
| } | } | ||||
| return unpack_graph_node; | return unpack_graph_node; | ||||
| } | } | ||||
| // get metagraph of value node | // get metagraph of value node | ||||
| MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { | |||||
| MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { | |||||
| ValuePtr value; | ValuePtr value; | ||||
| if (IsValueNode<prim::DoSignaturePrimitive>(node)) { | if (IsValueNode<prim::DoSignaturePrimitive>(node)) { | ||||
| value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function(); | value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function(); | ||||
| @@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { | |||||
| } | } | ||||
| // check if node is a specific metafuncgraph op | // check if node is a specific metafuncgraph op | ||||
| bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { | |||||
| bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { | |||||
| if (node != nullptr) { | if (node != nullptr) { | ||||
| auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); | auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); | ||||
| if (meta_func_graph_ptr == nullptr) { | if (meta_func_graph_ptr == nullptr) { | ||||
| @@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr | |||||
| // {{GradOperation, g, w}, Ys} | // {{GradOperation, g, w}, Ys} | ||||
| // {UnPackCall, {GradOperation, g, w}, Ys} | // {UnPackCall, {GradOperation, g, w}, Ys} | ||||
| AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { | |||||
| AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | if (!node->isa<CNode>() || node->func_graph() == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -31,20 +31,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /* namespace to support opt */ | /* namespace to support opt */ | ||||
| namespace opt { | namespace opt { | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, | |||||
| const RenormAction& renorm_action) { | |||||
| auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||||
| const RenormAction &renorm_action) { | |||||
| auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | return std::make_shared<Substitution>(transform, name, fn, renorm_action); | ||||
| } | } | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||||
| const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) { | |||||
| auto fn = [prims](const AnfNodePtr& node) -> bool { | |||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | |||||
| auto fn = [prims](const AnfNodePtr &node) -> bool { | |||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (auto& prim : prims) { | |||||
| for (auto &prim : prims) { | |||||
| if (IsPrimitiveCNode(node, prim)) { | if (IsPrimitiveCNode(node, prim)) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: | |||||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | return std::make_shared<Substitution>(transform, name, fn, renorm_action); | ||||
| } | } | ||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, | |||||
| const PredicateFuncType& predicate, const RenormAction& renorm_action) { | |||||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||||
| const PredicateFuncType &predicate, const RenormAction &renorm_action) { | |||||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | ||||
| } | } | ||||
| AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { | |||||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double t = GetTime(); | double t = GetTime(); | ||||
| #endif | #endif | ||||
| @@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode | |||||
| return result; | return result; | ||||
| } | } | ||||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, | |||||
| const SubstitutionPtr& transform) const { | |||||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||||
| const SubstitutionPtr &transform) const { | |||||
| FuncGraphManagerPtr manager = optimizer->manager(); | FuncGraphManagerPtr manager = optimizer->manager(); | ||||
| std::unordered_set<AnfNodePtr> seen_node; | std::unordered_set<AnfNodePtr> seen_node; | ||||
| std::deque<AnfNodePtr> todo{root_node}; | std::deque<AnfNodePtr> todo{root_node}; | ||||
| @@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo | |||||
| } | } | ||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto& inputs = node->cast<CNodePtr>()->inputs(); | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | ||||
| } | } | ||||
| auto& node_users = manager->node_users(); | |||||
| auto &node_users = manager->node_users(); | |||||
| if (change && node_users.find(node) != node_users.end()) { | if (change && node_users.find(node) != node_users.end()) { | ||||
| for (auto& use : node_users[node]) { | |||||
| for (auto &use : node_users[node]) { | |||||
| auto use_node = use.first; | auto use_node = use.first; | ||||
| todo.push_back(use_node); | todo.push_back(use_node); | ||||
| if (seen_node.find(use_node) != seen_node.end()) { | if (seen_node.find(use_node) != seen_node.end()) { | ||||
| @@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { | |||||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||||
| MS_EXCEPTION_IF_NULL(optimizer); | MS_EXCEPTION_IF_NULL(optimizer); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| FuncGraphManagerPtr manager = optimizer->manager(); | FuncGraphManagerPtr manager = optimizer->manager(); | ||||
| @@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize | |||||
| do { | do { | ||||
| loop = false; | loop = false; | ||||
| for (auto const& transform : list_) { | |||||
| for (auto const &transform : list_) { | |||||
| auto change = ApplyTransform(optimizer, func_graph->output(), transform); | auto change = ApplyTransform(optimizer, func_graph->output(), transform); | ||||
| changes = changes || change; | changes = changes || change; | ||||
| loop = loop || change; | loop = loop || change; | ||||
| @@ -28,7 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { | |||||
| std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { | |||||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | ||||
| MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " | MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " | ||||
| << MAX_RECURSIVE_CALL_TIMES; | << MAX_RECURSIVE_CALL_TIMES; | ||||
| @@ -39,7 +39,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto node_set = manager->node_users()[para]; | auto node_set = manager->node_users()[para]; | ||||
| std::unordered_set<CNodePtr> cnode_set; | std::unordered_set<CNodePtr> cnode_set; | ||||
| for (auto& node_pair : node_set) { | |||||
| for (auto &node_pair : node_set) { | |||||
| auto cnode = node_pair.first->cast<CNodePtr>(); | auto cnode = node_pair.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| @@ -54,7 +54,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||||
| (void)cnode_set.emplace(cnode); | (void)cnode_set.emplace(cnode); | ||||
| } else { | } else { | ||||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | ||||
| for (auto& cnode_sub : cnode_set_sub) { | |||||
| for (auto &cnode_sub : cnode_set_sub) { | |||||
| (void)cnode_set.emplace(cnode_sub); | (void)cnode_set.emplace(cnode_sub); | ||||
| } | } | ||||
| } | } | ||||
| @@ -63,8 +63,8 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr& para, uint32_t | |||||
| } | } | ||||
| Status AllreduceFusion::AddNodeToGraph() { | Status AllreduceFusion::AddNodeToGraph() { | ||||
| const auto& parameters = root_graph_->parameters(); | |||||
| for (auto& parameter : parameters) { | |||||
| const auto ¶meters = root_graph_->parameters(); | |||||
| for (auto ¶meter : parameters) { | |||||
| if (!ParameterRequireGrad(parameter)) { | if (!ParameterRequireGrad(parameter)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() { | |||||
| if (cnode_set.empty()) { | if (cnode_set.empty()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| for (auto& cnode : cnode_set) { | |||||
| for (auto &cnode : cnode_set) { | |||||
| MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); | MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); | ||||
| if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { | if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { | ||||
| MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); | MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); | ||||
| @@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { | |||||
| CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { | |||||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | ||||
| MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " | MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " | ||||
| << MAX_RECURSIVE_CALL_TIMES; | << MAX_RECURSIVE_CALL_TIMES; | ||||
| @@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi | |||||
| return cnode_dist; | return cnode_dist; | ||||
| } else { | } else { | ||||
| auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); | auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); | ||||
| for (auto& ele : cnode_dist_next) { | |||||
| for (auto &ele : cnode_dist_next) { | |||||
| cnode_dist[ele.first] = cost + ele.second; | cnode_dist[ele.first] = cost + ele.second; | ||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| auto cnode_dist_next = FindNextCNodes(cnode); | auto cnode_dist_next = FindNextCNodes(cnode); | ||||
| for (auto& ele : cnode_dist_next) { | |||||
| for (auto &ele : cnode_dist_next) { | |||||
| cnode_dist[ele.first] = ele.second; | cnode_dist[ele.first] = ele.second; | ||||
| } | } | ||||
| } | } | ||||
| return cnode_dist; | return cnode_dist; | ||||
| } | } | ||||
| CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { | |||||
| CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { | |||||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | ||||
| MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " | MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " | ||||
| << MAX_RECURSIVE_CALL_TIMES; | << MAX_RECURSIVE_CALL_TIMES; | ||||
| } | } | ||||
| const auto& from_inputs = from->inputs(); | |||||
| const auto &from_inputs = from->inputs(); | |||||
| std::unordered_map<CNodePtr, double> dist_map; | std::unordered_map<CNodePtr, double> dist_map; | ||||
| MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; | MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; | ||||
| for (auto& input_node : from_inputs) { | |||||
| for (auto &input_node : from_inputs) { | |||||
| auto cnode_dist = FindCNode(input_node, recursive_times + 1); | auto cnode_dist = FindCNode(input_node, recursive_times + 1); | ||||
| for (auto& ele : cnode_dist) { | |||||
| for (auto &ele : cnode_dist) { | |||||
| (void)dist_map.emplace(ele); | (void)dist_map.emplace(ele); | ||||
| } | } | ||||
| } | } | ||||
| @@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu | |||||
| Status AllreduceFusion::AddEdgeToGraph() { | Status AllreduceFusion::AddEdgeToGraph() { | ||||
| std::unordered_map<CNodePtr, int32_t> cnode_state_map; | std::unordered_map<CNodePtr, int32_t> cnode_state_map; | ||||
| const auto& cnodes = allreduce_graph_.cnode_set(); | |||||
| for (auto& cnode : cnodes) { | |||||
| const auto &cnodes = allreduce_graph_.cnode_set(); | |||||
| for (auto &cnode : cnodes) { | |||||
| cnode_state_map[cnode] = 0; | cnode_state_map[cnode] = 0; | ||||
| } | } | ||||
| const auto& head_cnode = allreduce_graph_.head_cnode(); | |||||
| const auto &head_cnode = allreduce_graph_.head_cnode(); | |||||
| std::queue<CNodePtr> cnode_queue; | std::queue<CNodePtr> cnode_queue; | ||||
| cnode_queue.emplace(head_cnode); | cnode_queue.emplace(head_cnode); | ||||
| cnode_state_map[head_cnode] = 1; | cnode_state_map[head_cnode] = 1; | ||||
| @@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() { | |||||
| cnode_queue.pop(); | cnode_queue.pop(); | ||||
| cnode_state_map[cur_cnode] = 2; | cnode_state_map[cur_cnode] = 2; | ||||
| auto next = FindNextCNodes(cur_cnode); | auto next = FindNextCNodes(cur_cnode); | ||||
| for (auto& ele : next) { | |||||
| auto& cnode = ele.first; | |||||
| auto& dist = ele.second; | |||||
| for (auto &ele : next) { | |||||
| auto &cnode = ele.first; | |||||
| auto &dist = ele.second; | |||||
| if (cnode_state_map[cnode] == 0) { | if (cnode_state_map[cnode] == 0) { | ||||
| cnode_queue.emplace(cnode); | cnode_queue.emplace(cnode); | ||||
| cnode_state_map[cnode] = 1; | cnode_state_map[cnode] = 1; | ||||
| @@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { | |||||
| std::vector<CNodePtr> FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { | |||||
| if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { | ||||
| MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " | MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " | ||||
| << MAX_RECURSIVE_CALL_TIMES; | << MAX_RECURSIVE_CALL_TIMES; | ||||
| @@ -184,7 +184,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| AnfNodeIndexSet node_set = manager->node_users()[para]; | AnfNodeIndexSet node_set = manager->node_users()[para]; | ||||
| std::vector<CNodePtr> cnode_list; | std::vector<CNodePtr> cnode_list; | ||||
| for (auto& node_pair : node_set) { | |||||
| for (auto &node_pair : node_set) { | |||||
| auto cnode = node_pair.first->cast<CNodePtr>(); | auto cnode = node_pair.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| @@ -210,7 +210,7 @@ std::vector<CNodePtr> FindMirror(const AnfNodePtr& para, uint32_t recursive_time | |||||
| return cnode_list; | return cnode_list; | ||||
| } | } | ||||
| void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { | |||||
| void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { | |||||
| MS_EXCEPTION_IF_NULL(mirror_cnode); | MS_EXCEPTION_IF_NULL(mirror_cnode); | ||||
| MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; | MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; | ||||
| auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0)); | auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0)); | ||||
| @@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st | |||||
| (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name))); | (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name))); | ||||
| } | } | ||||
| Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { | |||||
| Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { | |||||
| auto mirror_cnodes = FindMirror(para); | auto mirror_cnodes = FindMirror(para); | ||||
| if (mirror_cnodes.empty()) { | if (mirror_cnodes.empty()) { | ||||
| MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; | MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (mirror_cnodes.size() > 2) { | if (mirror_cnodes.size() > 2) { | ||||
| for (auto& mirror_cnode : mirror_cnodes) { | |||||
| for (auto &mirror_cnode : mirror_cnodes) { | |||||
| MS_EXCEPTION_IF_NULL(mirror_cnode); | MS_EXCEPTION_IF_NULL(mirror_cnode); | ||||
| MS_LOG(INFO) << mirror_cnode->DebugString(); | MS_LOG(INFO) << mirror_cnode->DebugString(); | ||||
| } | } | ||||
| @@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { | |||||
| << "Mirror CNode found."; | << "Mirror CNode found."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (auto& mirror_cnode : mirror_cnodes) { | |||||
| for (auto &mirror_cnode : mirror_cnodes) { | |||||
| auto parameter_name = ParameterName(para); | auto parameter_name = ParameterName(para); | ||||
| SetMirrorFusion(mirror_cnode, fusion, parameter_name); | SetMirrorFusion(mirror_cnode, fusion, parameter_name); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusion) { | |||||
| for (auto& param_node : paras) { | |||||
| Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr> ¶s, int32_t fusion) { | |||||
| for (auto ¶m_node : paras) { | |||||
| if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { | if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { | ||||
| MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; | MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector<AnfNodePtr>& paras, int32_t fusi | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AllreduceFusion::SetFusion(const std::vector<double>& cost_map) { | |||||
| Status AllreduceFusion::SetFusion(const std::vector<double> &cost_map) { | |||||
| if (cost_map.size() < 2) { | if (cost_map.size() < 2) { | ||||
| MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); | MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { | |||||
| return SetFusionByBackwardCompAndAllreduceTime(); | return SetFusionByBackwardCompAndAllreduceTime(); | ||||
| } | } | ||||
| Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { | |||||
| Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| MS_LOG(ERROR) << "ret is nullptr."; | MS_LOG(ERROR) << "ret is nullptr."; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -50,15 +50,15 @@ class AllreduceFusion { | |||||
| allreduce_bandwidth_(0), | allreduce_bandwidth_(0), | ||||
| computation_time_parameter_(0) {} | computation_time_parameter_(0) {} | ||||
| virtual ~AllreduceFusion() = default; | virtual ~AllreduceFusion() = default; | ||||
| Status ProcessAllreduceFusion(const CNodePtr& ret); | |||||
| Status ProcessAllreduceFusion(const CNodePtr &ret); | |||||
| private: | private: | ||||
| Status AddNodeToGraph(); | Status AddNodeToGraph(); | ||||
| CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; | |||||
| CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; | |||||
| CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; | |||||
| CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; | |||||
| Status AddEdgeToGraph(); | Status AddEdgeToGraph(); | ||||
| std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const; | std::vector<double> GenerateCostMap(int32_t fusion_times, double tail_percent) const; | ||||
| Status SetFusion(const std::vector<double>& cost_map); | |||||
| Status SetFusion(const std::vector<double> &cost_map); | |||||
| Status SetFusionByAlgorithm(int32_t algorithm); | Status SetFusionByAlgorithm(int32_t algorithm); | ||||
| Status SetFusionByBackwardCompTime(); | Status SetFusionByBackwardCompTime(); | ||||
| Status SetFusionByBackwardCompAndAllreduceTime(); | Status SetFusionByBackwardCompAndAllreduceTime(); | ||||
| @@ -23,7 +23,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { | |||||
| Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { | |||||
| AllreduceNodePtr arnode; | AllreduceNodePtr arnode; | ||||
| auto cnode_emplace_return = cnode_set_.emplace(node); | auto cnode_emplace_return = cnode_set_.emplace(node); | ||||
| if (!cnode_emplace_return.second) { | if (!cnode_emplace_return.second) { | ||||
| @@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { | |||||
| Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { | |||||
| auto from_arnode_iter = cnode_arnode_map_.find(from); | auto from_arnode_iter = cnode_arnode_map_.find(from); | ||||
| if (from_arnode_iter == cnode_arnode_map_.end()) { | if (from_arnode_iter == cnode_arnode_map_.end()) { | ||||
| MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; | MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; | ||||
| @@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { | |||||
| bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { | |||||
| auto cnode_iter = cnode_set_.find(node); | auto cnode_iter = cnode_set_.find(node); | ||||
| return !(cnode_iter == cnode_set_.end()); | return !(cnode_iter == cnode_set_.end()); | ||||
| } | } | ||||
| std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) { | std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) { | ||||
| std::vector<AnfNodePtr> nodes; | std::vector<AnfNodePtr> nodes; | ||||
| for (auto& cnode_arnode : cnode_arnode_map_) { | |||||
| for (auto &cnode_arnode : cnode_arnode_map_) { | |||||
| MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() | MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() | ||||
| << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() | << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() | ||||
| << " curr_para_size: " << cnode_arnode.second->curr_para_size(); | << " curr_para_size: " << cnode_arnode.second->curr_para_size(); | ||||
| @@ -117,7 +117,7 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou | |||||
| std::vector<AnfNodePtr> nodes; | std::vector<AnfNodePtr> nodes; | ||||
| double cur_para_size = 0; | double cur_para_size = 0; | ||||
| double from = to; | double from = to; | ||||
| for (auto& arnode : arnode_vec_) { | |||||
| for (auto &arnode : arnode_vec_) { | |||||
| if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { | if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -135,14 +135,14 @@ std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(dou | |||||
| void AllreduceGraph::PrintCNodeSet() const { | void AllreduceGraph::PrintCNodeSet() const { | ||||
| MS_LOG(INFO) << "CNodeSet:"; | MS_LOG(INFO) << "CNodeSet:"; | ||||
| for (auto& cnode : cnode_set_) { | |||||
| for (auto &cnode : cnode_set_) { | |||||
| MS_LOG(INFO) << cnode->DebugString(); | MS_LOG(INFO) << cnode->DebugString(); | ||||
| } | } | ||||
| } | } | ||||
| void AllreduceGraph::PrintAllredueGraphInfo() const { | void AllreduceGraph::PrintAllredueGraphInfo() const { | ||||
| MS_LOG(INFO) << "max: " << max_; | MS_LOG(INFO) << "max: " << max_; | ||||
| for (auto& cnode_arnode : cnode_arnode_map_) { | |||||
| for (auto &cnode_arnode : cnode_arnode_map_) { | |||||
| MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); | MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); | ||||
| MS_LOG(INFO) << "arnode info: "; | MS_LOG(INFO) << "arnode info: "; | ||||
| cnode_arnode.second->ToString(); | cnode_arnode.second->ToString(); | ||||
| @@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const { | |||||
| void AllreduceGraph::PrintArnodeVec() const { | void AllreduceGraph::PrintArnodeVec() const { | ||||
| MS_LOG(INFO) << "ArnodeVec:"; | MS_LOG(INFO) << "ArnodeVec:"; | ||||
| for (auto& arnode : arnode_vec_) { | |||||
| for (auto &arnode : arnode_vec_) { | |||||
| arnode.ToString(); | arnode.ToString(); | ||||
| } | } | ||||
| } | } | ||||
| void AllreduceGraph::PrintArnodeSet() const { | void AllreduceGraph::PrintArnodeSet() const { | ||||
| MS_LOG(INFO) << "ArnodeSet:"; | MS_LOG(INFO) << "ArnodeSet:"; | ||||
| for (auto& arnode : arnode_set_) { | |||||
| for (auto &arnode : arnode_set_) { | |||||
| arnode->ToString(); | arnode->ToString(); | ||||
| } | } | ||||
| } | } | ||||
| void AllreduceGraph::SortArnode() { | void AllreduceGraph::SortArnode() { | ||||
| arnode_vec_.clear(); | arnode_vec_.clear(); | ||||
| for (auto& node : arnode_set_) { | |||||
| for (auto &node : arnode_set_) { | |||||
| arnode_vec_.emplace_back(*node); | arnode_vec_.emplace_back(*node); | ||||
| } | } | ||||
| std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); | std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); | ||||
| @@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() { | |||||
| Status AllreduceGraph::RemoveExtraParas() { | Status AllreduceGraph::RemoveExtraParas() { | ||||
| std::unordered_set<AnfNodePtr> para_map; | std::unordered_set<AnfNodePtr> para_map; | ||||
| for (auto& node : arnode_vec_) { | |||||
| for (auto& para : node.paras()) { | |||||
| for (auto &node : arnode_vec_) { | |||||
| for (auto ¶ : node.paras()) { | |||||
| auto emplac_result = para_map.emplace(para); | auto emplac_result = para_map.emplace(para); | ||||
| if (!emplac_result.second) { | if (!emplac_result.second) { | ||||
| MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; | MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; | ||||
| @@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { | |||||
| Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { | |||||
| auto arnode = std::make_shared<AllreduceNode>(AllreduceNode()); | auto arnode = std::make_shared<AllreduceNode>(AllreduceNode()); | ||||
| if (arnode->Init(node) != SUCCESS) { | if (arnode->Init(node) != SUCCESS) { | ||||
| MS_LOG(ERROR) << "AllreduceNode Init failed"; | MS_LOG(ERROR) << "AllreduceNode Init failed"; | ||||
| @@ -42,9 +42,9 @@ class AllreduceGraph { | |||||
| cnode_arnode_map_(), | cnode_arnode_map_(), | ||||
| max_(0) {} | max_(0) {} | ||||
| virtual ~AllreduceGraph() = default; | virtual ~AllreduceGraph() = default; | ||||
| Status AddNode(const CNodePtr& node, const AnfNodePtr& para); | |||||
| Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); | |||||
| bool NodeInGraph(const CNodePtr& node) const; | |||||
| Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); | |||||
| Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); | |||||
| bool NodeInGraph(const CNodePtr &node) const; | |||||
| std::vector<AnfNodePtr> GetParaByCost(double from, double to); | std::vector<AnfNodePtr> GetParaByCost(double from, double to); | ||||
| // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is | // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is | ||||
| // over para_size. | // over para_size. | ||||
| @@ -60,9 +60,9 @@ class AllreduceGraph { | |||||
| void PrintAllredueGraphInfo() const; | void PrintAllredueGraphInfo() const; | ||||
| void PrintArnodeVec() const; | void PrintArnodeVec() const; | ||||
| void PrintArnodeSet() const; | void PrintArnodeSet() const; | ||||
| const std::unordered_set<CNodePtr>& cnode_set() const { return cnode_set_; } | |||||
| const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; } | |||||
| CNodePtr head_cnode() const { return head_cnode_; } | CNodePtr head_cnode() const { return head_cnode_; } | ||||
| Status set_head_cnode(const CNodePtr& node); | |||||
| Status set_head_cnode(const CNodePtr &node); | |||||
| double max() const { return max_; } | double max() const { return max_; } | ||||
| private: | private: | ||||