| @@ -108,7 +108,7 @@ | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "debug/dump_proto.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -34,7 +34,7 @@ | |||||
| #include "runtime/device/ascend/ascend_stream_assign.h" | #include "runtime/device/ascend/ascend_stream_assign.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "debug/dump_proto.h" | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "backend/optimizer/common/helper.h" | #include "backend/optimizer/common/helper.h" | ||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "debug/dump_proto.h" | |||||
| #include "backend/session/gpu_session.h" | #include "backend/session/gpu_session.h" | ||||
| #include "runtime/device/gpu/kernel_info_setter.h" | #include "runtime/device/gpu/kernel_info_setter.h" | ||||
| #include "runtime/device/gpu/gpu_kernel_build.h" | #include "runtime/device/gpu/gpu_kernel_build.h" | ||||
| @@ -2258,64 +2258,4 @@ std::vector<FuncGraphPtr> ImportIR(const std::string &filename) { | |||||
| parser.ParseFile(); | parser.ParseFile(); | ||||
| return parser.GetFuncGraphs(); | return parser.GetFuncGraphs(); | ||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||||
| if (func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Func graph is nullptr"; | |||||
| return; | |||||
| } | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| if (ms_context == nullptr) { | |||||
| MS_LOG(ERROR) << "ms_context is nullptr"; | |||||
| return; | |||||
| } | |||||
| auto save_graphs_path = ms_context->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| std::string file_path = save_graphs_path + "/" + "ms_output_" + suffix + ".pb"; | |||||
| if (file_path.size() > PATH_MAX) { | |||||
| MS_LOG(ERROR) << "File path " << file_path << " is too long."; | |||||
| return; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| char *real_path_ret = nullptr; | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | |||||
| #else | |||||
| real_path_ret = realpath(file_path.c_str(), real_path); | |||||
| #endif | |||||
| if (nullptr == real_path_ret) { | |||||
| MS_LOG(DEBUG) << "dir " << file_path << " does not exit."; | |||||
| } else { | |||||
| std::string path_string = real_path; | |||||
| if (chmod(common::SafeCStr(path_string), S_IRUSR | S_IWUSR) == -1) { | |||||
| MS_LOG(ERROR) << "Modify file:" << real_path << " to rw fail."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| // write to pb file | |||||
| std::ofstream ofs(real_path); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file '" << real_path << "' failed!"; | |||||
| return; | |||||
| } | |||||
| ofs << GetFuncGraphProtoString(func_graph); | |||||
| ofs.close(); | |||||
| // set file mode to read only by user | |||||
| ChangeFileMode(file_path, S_IRUSR); | |||||
| } | |||||
| #else | |||||
| void DumpIRProto(const FuncGraphPtr &, const std::string &) { | |||||
| static bool already_printed = false; | |||||
| if (already_printed) { | |||||
| return; | |||||
| } | |||||
| already_printed = true; | |||||
| MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, " | |||||
| << "please recompile source to enable it. See help of building script."; | |||||
| } | |||||
| #endif | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -112,14 +112,6 @@ void ExportIR(const std::string &filename, const std::string &id, const FuncGrap | |||||
| void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs); | ||||
| std::vector<FuncGraphPtr> ImportIR(const std::string &filename); | std::vector<FuncGraphPtr> ImportIR(const std::string &filename); | ||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||||
| std::string GetBinaryProtoString(const FuncGraphPtr &func_graph); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ | ||||
| @@ -13,16 +13,20 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "debug/dump_proto.h" | |||||
| #include <algorithm> | |||||
| #include <fstream> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include <algorithm> | |||||
| #include <vector> | |||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "proto/anf_ir.pb.h" | #include "proto/anf_ir.pb.h" | ||||
| #include "ir/graph_utils.h" | #include "ir/graph_utils.h" | ||||
| #include "utils/ms_context.h" | |||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ProtoExporter { | class ProtoExporter { | ||||
| @@ -514,4 +518,64 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { | |||||
| ProtoExporter exporter; | ProtoExporter exporter; | ||||
| return exporter.GetFuncGraphProtoString(func_graph); | return exporter.GetFuncGraphProtoString(func_graph); | ||||
| } | } | ||||
| #ifdef ENABLE_DUMP_IR | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { | |||||
| if (func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Func graph is nullptr"; | |||||
| return; | |||||
| } | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| if (ms_context == nullptr) { | |||||
| MS_LOG(ERROR) << "ms_context is nullptr"; | |||||
| return; | |||||
| } | |||||
| auto save_graphs_path = ms_context->save_graphs_path(); | |||||
| if (save_graphs_path.empty()) { | |||||
| save_graphs_path = "."; | |||||
| } | |||||
| std::string file_path = save_graphs_path + "/" + "ms_output_" + suffix + ".pb"; | |||||
| if (file_path.size() > PATH_MAX) { | |||||
| MS_LOG(ERROR) << "File path " << file_path << " is too long."; | |||||
| return; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| char *real_path_ret = nullptr; | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); | |||||
| #else | |||||
| real_path_ret = realpath(file_path.c_str(), real_path); | |||||
| #endif | |||||
| if (nullptr == real_path_ret) { | |||||
| MS_LOG(DEBUG) << "dir " << file_path << " does not exit."; | |||||
| } else { | |||||
| std::string path_string = real_path; | |||||
| if (chmod(common::SafeCStr(path_string), S_IRUSR | S_IWUSR) == -1) { | |||||
| MS_LOG(ERROR) << "Modify file:" << real_path << " to rw fail."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| // write to pb file | |||||
| std::ofstream ofs(real_path); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file '" << real_path << "' failed!"; | |||||
| return; | |||||
| } | |||||
| ofs << GetFuncGraphProtoString(func_graph); | |||||
| ofs.close(); | |||||
| // set file mode to read only by user | |||||
| ChangeFileMode(file_path, S_IRUSR); | |||||
| } | |||||
| #else | |||||
| void DumpIRProto(const FuncGraphPtr &, const std::string &) { | |||||
| static bool already_printed = false; | |||||
| if (already_printed) { | |||||
| return; | |||||
| } | |||||
| already_printed = true; | |||||
| MS_LOG(WARNING) << "The functionality of dumping function graph IR in protobuf format is disabled, " | |||||
| << "please recompile source to enable it. See help of building script."; | |||||
| } | |||||
| #endif | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_DEBUG_DUMP_PROTO_H_ | |||||
| #define MINDSPORE_CCSRC_DEBUG_DUMP_PROTO_H_ | |||||
| #include <string> | |||||
| #include "ir/func_graph.h" | |||||
| namespace mindspore { | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); | |||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); | |||||
| std::string GetBinaryProtoString(const FuncGraphPtr &func_graph); | |||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_DEBUG_DUMP_PROTO_H_ | |||||
| @@ -27,6 +27,8 @@ | |||||
| #include "pybind_api/ir/tensor_py.h" | #include "pybind_api/ir/tensor_py.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "abstract/infer_functions.h" | #include "abstract/infer_functions.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| enum State { | enum State { | ||||
| @@ -17,10 +17,14 @@ | |||||
| */ | */ | ||||
| #include "frontend/optimizer/cse.h" | #include "frontend/optimizer/cse.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "abstract/abstract_function.h" | |||||
| #include "utils/flags.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /* namespace to support opt */ | /* namespace to support opt */ | ||||
| namespace opt { | namespace opt { | ||||
| @@ -24,23 +24,16 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "frontend/optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /* namespace to support opt */ | /* namespace to support opt */ | ||||
| namespace opt { | namespace opt { | ||||
| // Common subexpression elimination. | // Common subexpression elimination. | ||||
| class CSE { | class CSE { | ||||
| public: | public: | ||||
| explicit CSE(bool report_changes = true) : report_changes_(report_changes) {} | |||||
| CSE() = default; | |||||
| virtual ~CSE() = default; | virtual ~CSE() = default; | ||||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||||
| bool chg = Cse(root, optimizer->resource()->manager()); | |||||
| return chg && report_changes_; | |||||
| } | |||||
| virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; | virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; | ||||
| virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; | virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; | ||||
| @@ -51,7 +44,6 @@ class CSE { | |||||
| bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; | bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; | ||||
| bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, | bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, | ||||
| std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const; | std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const; | ||||
| bool report_changes_; | |||||
| }; | }; | ||||
| BasePtr AbsOf(const AnfNodePtr &node); | BasePtr AbsOf(const AnfNodePtr &node); | ||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_PASS_H_ | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <memory> | |||||
| #include "frontend/optimizer/cse.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| namespace mindspore { | |||||
| /* namespace to support opt */ | |||||
| namespace opt { | |||||
| // Common subexpression elimination. | |||||
| class CSEPass : public CSE { | |||||
| public: | |||||
| explicit CSEPass(bool report_changes = true) : CSE(), report_changes_(report_changes) {} | |||||
| virtual ~CSEPass() = default; | |||||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||||
| bool chg = Cse(root, optimizer->resource()->manager()); | |||||
| return chg && report_changes_; | |||||
| } | |||||
| private: | |||||
| bool report_changes_; | |||||
| }; | |||||
| BasePtr AbsOf(const AnfNodePtr &node); | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CSE_PASS_H_ | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "pipeline/jit/parse/python_adapter.h" | #include "pipeline/jit/parse/python_adapter.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| using mindspore::tensor::Tensor; | using mindspore::tensor::Tensor; | ||||
| @@ -28,7 +28,7 @@ | |||||
| #include "pipeline/jit/validator.h" | #include "pipeline/jit/validator.h" | ||||
| #include "pipeline/jit/remove_value_node_dup.h" | #include "pipeline/jit/remove_value_node_dup.h" | ||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/optimizer/cse.h" | |||||
| #include "frontend/optimizer/cse_pass.h" | |||||
| #include "frontend/optimizer/graph_kernel_reuse.h" | #include "frontend/optimizer/graph_kernel_reuse.h" | ||||
| #include "frontend/optimizer/clean.h" | #include "frontend/optimizer/clean.h" | ||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| @@ -158,7 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| {"resolve", resolve_pass}, | {"resolve", resolve_pass}, | ||||
| {"a_after_grad", a_after_grad}, | {"a_after_grad", a_after_grad}, | ||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||||
| {"a_3", a_3}}); | {"a_3", a_3}}); | ||||
| return map_a; | return map_a; | ||||
| @@ -192,7 +192,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| {"b_1", b_1}, | {"b_1", b_1}, | ||||
| {"b_2", b_2}, | {"b_2", b_2}, | ||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||||
| }); | }); | ||||
| return map; | return map; | ||||
| } | } | ||||
| @@ -205,7 +205,7 @@ OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &i | |||||
| {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, | {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, | ||||
| {"interface_fusion", interface_fusion}, | {"interface_fusion", interface_fusion}, | ||||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | {"renormalize", opt::OptPassConfig::Renormalize()}, | ||||
| {"cse", opt::OptPassConfig(opt::CSE(false))}, | |||||
| {"cse", opt::OptPassConfig(opt::CSEPass(false))}, | |||||
| }); | }); | ||||
| return map; | return map; | ||||
| } | } | ||||
| @@ -29,9 +29,11 @@ | |||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| #include "frontend/optimizer/ad/dfunctor.h" | #include "frontend/optimizer/ad/dfunctor.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "debug/dump_proto.h" | |||||
| #include "debug/anf_ir_utils.h" | #include "debug/anf_ir_utils.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| #include "utils/context/context_extends.h" | #include "utils/context/context_extends.h" | ||||
| #include "vm/segment_runner.h" | #include "vm/segment_runner.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| @@ -30,6 +30,7 @@ | |||||
| #include "transform/graph_ir/graph_runner.h" | #include "transform/graph_ir/graph_runner.h" | ||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| @@ -17,11 +17,13 @@ | |||||
| #include "pipeline/jit/remove_value_node_dup.h" | #include "pipeline/jit/remove_value_node_dup.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/func_graph.h" | |||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "frontend/optimizer/cse.h" | #include "frontend/optimizer/cse.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/hashing.h" | #include "utils/hashing.h" | ||||
| #include "utils/convert_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| @@ -34,6 +34,7 @@ | |||||
| #include "pipeline/jit/resource.h" | #include "pipeline/jit/resource.h" | ||||
| #include "pipeline/jit/parse/resolve.h" | #include "pipeline/jit/parse/resolve.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| @@ -30,6 +30,7 @@ | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/context/context_extends.h" | #include "utils/context/context_extends.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/operator/composite/composite.h" | #include "frontend/operator/composite/composite.h" | ||||
| #include "frontend/operator/composite/do_signature.h" | #include "frontend/operator/composite/do_signature.h" | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "pipeline/jit/parse/data_converter.h" | #include "pipeline/jit/parse/data_converter.h" | ||||
| #include "pybind11/pytypes.h" | #include "pybind11/pytypes.h" | ||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| #include "utils/primitive_utils.h" | #include "utils/primitive_utils.h" | ||||
| #include "utils/base_ref_extends.h" | #include "utils/base_ref_extends.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| @@ -21,6 +21,13 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <atomic> | #include <atomic> | ||||
| #include "pybind11/pybind11.h" | |||||
| #include "utils/ms_utils.h" | |||||
| #include "utils/convert_utils_base.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace context { | namespace context { | ||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/tensorprint_utils.h" | #include "utils/tensorprint_utils.h" | ||||
| #include "utils/convert_utils.h" | |||||
| #ifndef NO_DLIB | #ifndef NO_DLIB | ||||
| #include "tdt/tsd_client.h" | #include "tdt/tsd_client.h" | ||||
| @@ -24,211 +24,14 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <cfloat> | #include <cfloat> | ||||
| #include "pybind11/pybind11.h" | |||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "pipeline/jit/parse/parse.h" | |||||
| #include "pipeline/jit/parse/parse_base.h" | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "utils/base_ref_extends.h" | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| py::object BuiltinsToPyData(const Any &value); | |||||
| py::object BuiltinsToPyData(const BaseRef &value); | |||||
| py::object VectorToPyData(const Any &value); | |||||
| py::object VectorRefToPyData(const VectorRef &value); | |||||
| py::object ValuePtrToPyData(const ValuePtr &value) { | |||||
| if (value == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value is null"; | |||||
| } | |||||
| py::object ret; | |||||
| if (value->isa<Int8Imm>()) { | |||||
| MS_LOG(DEBUG) << "int8"; | |||||
| py::int_ v = value->cast<Int8ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int16Imm>()) { | |||||
| MS_LOG(DEBUG) << "int16"; | |||||
| py::int_ v = value->cast<Int16ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int32Imm>()) { | |||||
| MS_LOG(DEBUG) << "int32"; | |||||
| py::int_ v = value->cast<Int32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int64Imm>()) { | |||||
| MS_LOG(DEBUG) << "int64"; | |||||
| py::int_ v = value->cast<Int64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt8Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint8"; | |||||
| py::int_ v = value->cast<UInt8ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt16Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint16"; | |||||
| py::int_ v = value->cast<UInt16ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt32Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint32"; | |||||
| py::int_ v = value->cast<UInt32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt64Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint64"; | |||||
| py::int_ v = value->cast<UInt64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<BoolImm>()) { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ v = value->cast<BoolImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<FP64Imm>()) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ v = value->cast<FP64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<FP32Imm>()) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ v = value->cast<FP32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<StringImm>()) { | |||||
| MS_LOG(DEBUG) << "String"; | |||||
| py::str v = value->cast<StringImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<tensor::Tensor>()) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<tensor::TensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<tensor::MetaTensor>()) { | |||||
| MS_LOG(DEBUG) << "MetaTensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<tensor::MetaTensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<RefKey>()) { | |||||
| MS_LOG(DEBUG) << "RefKey"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<RefKeyPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<ValueTuple>()) { | |||||
| MS_LOG(DEBUG) << "tuple"; | |||||
| auto value_tuple = value->cast<ValueTuplePtr>()->value(); | |||||
| py::tuple rets(value_tuple.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_tuple) { | |||||
| rets[i] = ValuePtrToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value->isa<ValueList>()) { | |||||
| MS_LOG(DEBUG) << "list"; | |||||
| auto value_list = value->cast<ValueListPtr>()->value(); | |||||
| py::list rets(value_list.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_list) { | |||||
| rets[i] = ValuePtrToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value->isa<Ellipsis>()) { | |||||
| ret = py::ellipsis(); | |||||
| } else if (value->isa<ValueSlice>()) { | |||||
| auto slice = value->cast<ValueSlicePtr>(); | |||||
| auto start = ValuePtrToPyData(slice->start()); | |||||
| auto end = ValuePtrToPyData(slice->stop()); | |||||
| auto step = ValuePtrToPyData(slice->step()); | |||||
| ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end, | |||||
| step); | |||||
| } else if (value->isa<Type>()) { | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<TypePtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<AnyValue>()) { | |||||
| ret = py::none(); | |||||
| } else if (value->isa<None>()) { | |||||
| ret = py::none(); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData."; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object AnyToPyData(const Any &value) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "AnyToPyData " << value.GetString(); | |||||
| if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) { | |||||
| ret = BuiltinsToPyData(value); | |||||
| } else if (value.is<ValuePtr>()) { | |||||
| MS_LOG(DEBUG) << "ValuePtr"; | |||||
| ValuePtr v = value.cast<ValuePtr>(); | |||||
| ret = ValuePtrToPyData(v); | |||||
| } else if (value.is<tensor::TensorPtr>()) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value.cast<tensor::TensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value.is<py::object>()) { | |||||
| MS_LOG(DEBUG) << "py obj"; | |||||
| ret = value.cast<py::object>(); | |||||
| } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) { | |||||
| ret = VectorToPyData(value); | |||||
| } else if (value.is<std::list<Any>>()) { | |||||
| MS_LOG(DEBUG) << "list_any"; | |||||
| auto value_list = value.cast<std::list<Any>>(); | |||||
| py::list rets = py::list(); | |||||
| for (auto &v : value_list) { | |||||
| rets.append(AnyToPyData(v)); | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value.is<std::vector<Any>>()) { | |||||
| auto value_list = value.cast<std::vector<Any>>(); | |||||
| py::tuple rets(value_list.size()); | |||||
| for (size_t i = 0; i < value_list.size(); i++) { | |||||
| rets[i] = AnyToPyData(value_list[i]); | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value.is<TypePtr>()) { | |||||
| py::tuple v(1); | |||||
| v[0] = value.cast<TypePtr>(); | |||||
| ret = v[0]; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "value is not support type"; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object BaseRefToPyData(const BaseRef &value) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString(); | |||||
| if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) { | |||||
| ret = BuiltinsToPyData(value); | |||||
| } else if (utils::isa<ValuePtr>(value)) { | |||||
| MS_LOG(DEBUG) << "ValuePtr"; | |||||
| ValuePtr v = utils::cast<ValuePtr>(value); | |||||
| ret = ValuePtrToPyData(v); | |||||
| } else if (utils::isa<tensor::TensorPtr>(value)) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = utils::cast<tensor::TensorPtr>(value); | |||||
| ret = v[0]; | |||||
| } else if (utils::isa<PyObjectRef>(value)) { | |||||
| MS_LOG(DEBUG) << "py obj"; | |||||
| PyObjectRef py_ref = utils::cast<PyObjectRef>(value); | |||||
| ret = py_ref.object_; | |||||
| } else if (utils::isa<VectorRef>(value)) { | |||||
| auto vec_ref = utils::cast<VectorRef>(value); | |||||
| ret = VectorRefToPyData(vec_ref); | |||||
| } else if (utils::isa<TypePtr>(value)) { | |||||
| py::tuple v(1); | |||||
| v[0] = utils::cast<TypePtr>(value); | |||||
| ret = v[0]; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "value is not support type"; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| bool ValueToBool(const ValuePtr &v, bool *value) { | bool ValueToBool(const ValuePtr &v, bool *value) { | ||||
| MS_EXCEPTION_IF_NULL(v); | MS_EXCEPTION_IF_NULL(v); | ||||
| if (v->isa<BoolImm>()) { | if (v->isa<BoolImm>()) { | ||||
| @@ -315,185 +118,6 @@ bool BaseRefToBool(const BaseRef &v, bool *value) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| py::object BuiltinsToPyData(const Any &value) { | |||||
| if (value.is<int>()) { | |||||
| MS_LOG(DEBUG) << "int"; | |||||
| py::int_ ret = value.cast<int>(); | |||||
| return std::move(ret); | |||||
| } else if (value.is<float>()) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ ret = value.cast<float>(); | |||||
| return std::move(ret); | |||||
| } else if (value.is<double>()) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ ret = value.cast<double>(); | |||||
| return std::move(ret); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ ret = value.cast<bool>(); | |||||
| return std::move(ret); | |||||
| } | |||||
| } | |||||
| py::object BuiltinsToPyData(const BaseRef &value) { | |||||
| if (utils::isa<int>(value)) { | |||||
| MS_LOG(DEBUG) << "int"; | |||||
| py::int_ ret = utils::cast<int>(value); | |||||
| return std::move(ret); | |||||
| } else if (utils::isa<float>(value)) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ ret = utils::cast<float>(value); | |||||
| return std::move(ret); | |||||
| } else if (utils::isa<double>(value)) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ ret = utils::cast<double>(value); | |||||
| return std::move(ret); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ ret = utils::cast<bool>(value); | |||||
| return std::move(ret); | |||||
| } | |||||
| } | |||||
| py::object VectorToPyData(const Any &value) { | |||||
| py::object ret; | |||||
| if (value.is<std::vector<tensor::TensorPtr>>()) { | |||||
| MS_LOG(DEBUG) << "vector_tensor"; | |||||
| std::vector<tensor::TensorPtr> outputs; | |||||
| outputs = value.cast<std::vector<tensor::TensorPtr>>(); | |||||
| py::tuple tensor_tuple(outputs.size()); | |||||
| for (std::size_t i = 0; i < outputs.size(); ++i) { | |||||
| tensor_tuple[i] = *outputs[i]; | |||||
| } | |||||
| ret = tensor_tuple; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "vector_any"; | |||||
| auto value_list = value.cast<std::vector<Any>>(); | |||||
| py::tuple any_tuple = py::tuple(value_list.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_list) { | |||||
| any_tuple[i] = AnyToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = any_tuple; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object VectorRefToPyData(const VectorRef &value_list) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "vector_ref"; | |||||
| size_t value_size = value_list.size(); | |||||
| auto ref_tuple = py::tuple(value_size); | |||||
| for (size_t i = 0; i < value_size; i++) { | |||||
| ref_tuple[i] = BaseRefToPyData(value_list[i]); | |||||
| } | |||||
| ret = ref_tuple; | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, | |||||
| const py::object &min_shape, const py::object &max_shape) { | |||||
| if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) { | |||||
| auto ret_vec = shape_obj.cast<ShapeVector>(); | |||||
| auto ret_dtype = type_obj.cast<TypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(ret_dtype); | |||||
| // if the size of shape list is empty, return an scalar abstract | |||||
| if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) { | |||||
| abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype); | |||||
| return abs_scalar; | |||||
| } | |||||
| AbstractBasePtr tensor = nullptr; | |||||
| ShapeVector min_shape_vec; | |||||
| ShapeVector max_shape_vec; | |||||
| if (!min_shape.is_none()) { | |||||
| min_shape_vec = min_shape.cast<ShapeVector>(); | |||||
| } | |||||
| if (!max_shape.is_none()) { | |||||
| max_shape_vec = max_shape.cast<ShapeVector>(); | |||||
| } | |||||
| auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec); | |||||
| if (ret_dtype->isa<TensorType>()) { | |||||
| auto tensor_type = type_obj.cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element()); | |||||
| tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape); | |||||
| } else { | |||||
| auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype); | |||||
| tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape); | |||||
| } | |||||
| return tensor; | |||||
| } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) { | |||||
| py::tuple shape_tuple = shape_obj.cast<py::tuple>(); | |||||
| py::tuple typeid_tuple = type_obj.cast<py::tuple>(); | |||||
| AbstractBasePtrList ptr_list; | |||||
| for (size_t it = 0; it < shape_tuple.size(); ++it) { | |||||
| auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]); | |||||
| ptr_list.push_back(tensor_it); | |||||
| } | |||||
| auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list); | |||||
| return tuple; | |||||
| } else if (shape_obj.is_none() && type_obj.is_none()) { | |||||
| // AbstractNone indicates there is no output for this CNode node. | |||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | |||||
| return abstract_none; | |||||
| } else { | |||||
| // When sparse enabled, the undetermined might be raised and eliminated in opt passes | |||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse = context->enable_sparse(); | |||||
| if (enable_sparse) { | |||||
| return std::make_shared<abstract::AbstractUndetermined>(); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj); | |||||
| } | |||||
| } | |||||
| bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, | |||||
| const std::shared_ptr<py::object> &ret_val) { | |||||
| if (output->isa<ValueNode>()) { | |||||
| MS_LOG(INFO) << "Graph's output is a constant. No need to execute."; | |||||
| ValuePtr value = GetValueNode(output); | |||||
| *ret_val = ValuePtrToPyData(value); | |||||
| return true; | |||||
| } | |||||
| // Adapter will transform values in __init__() and construct() to parameters, this could cause | |||||
| // inputs (a.k.a args in current function) size less than parameters'. | |||||
| if (output->isa<Parameter>()) { | |||||
| MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute."; | |||||
| // Find the right parameter as ret_val. | |||||
| auto func_graph = output->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto params = func_graph->parameters(); | |||||
| if ((args.size() + func_graph->hyper_param_count()) != params.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count() | |||||
| << " not equal to graph input size " << params.size() << ", let graph to be executed."; | |||||
| } | |||||
| auto it = std::find(params.begin(), params.end(), output); | |||||
| if (it == params.end()) { | |||||
| MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters"; | |||||
| } | |||||
| size_t index = it - params.cbegin(); | |||||
| if (index >= args.size() + func_graph->hyper_param_count()) { | |||||
| MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() | |||||
| << " add Parameter count " << func_graph->hyper_param_count() << "."; | |||||
| } | |||||
| if (index < args.size()) { | |||||
| *ret_val = args[index]; | |||||
| } else { | |||||
| auto param = dyn_cast<Parameter>(params[index]); | |||||
| MS_EXCEPTION_IF_NULL(param); | |||||
| if (!param->has_default()) { | |||||
| MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")"; | |||||
| } | |||||
| auto tensor = param->default_param(); | |||||
| *ret_val = py::cast(tensor); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| namespace { | namespace { | ||||
| // Isomorphism | // Isomorphism | ||||
| bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, | bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, | ||||
| @@ -25,14 +25,12 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "pybind11/pybind11.h" | |||||
| #include "utils/convert_utils_base.h" | #include "utils/convert_utils_base.h" | ||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "base/base_ref.h" | #include "base/base_ref.h" | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| namespace py = pybind11; | |||||
| #include "ir/func_graph.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace tensor { | namespace tensor { | ||||
| @@ -40,19 +38,9 @@ class Tensor; | |||||
| using TensorPtr = std::shared_ptr<Tensor>; | using TensorPtr = std::shared_ptr<Tensor>; | ||||
| } // namespace tensor | } // namespace tensor | ||||
| py::object AnyToPyData(const Any &value); | |||||
| py::object BaseRefToPyData(const BaseRef &value); | |||||
| bool BaseRefToBool(const BaseRef &in, bool *out); | bool BaseRefToBool(const BaseRef &in, bool *out); | ||||
| bool BaseRefToInt(const ValuePtr &v, int *value); | bool BaseRefToInt(const ValuePtr &v, int *value); | ||||
| bool ValueToBool(const ValuePtr &in, bool *out); | bool ValueToBool(const ValuePtr &in, bool *out); | ||||
| py::object ValuePtrToPyData(const ValuePtr &value); | |||||
| AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, | |||||
| const py::object &min_shape = py::none(), | |||||
| const py::object &max_shape = py::none()); | |||||
| bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, | |||||
| const std::shared_ptr<py::object> &ret_val); | |||||
| // Isomorphism | // Isomorphism | ||||
| struct PairHasher { | struct PairHasher { | ||||
| @@ -0,0 +1,409 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/convert_utils_py.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <algorithm> | |||||
| #include <list> | |||||
| #include <utility> | |||||
| #include <cfloat> | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "pipeline/jit/parse/parse.h" | |||||
| #include "pipeline/jit/parse/parse_base.h" | |||||
| #include "ir/value.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "ir/param_info.h" | |||||
| #include "pybind_api/ir/base_ref_py.h" | |||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | |||||
| py::object BuiltinsToPyData(const Any &value); | |||||
| py::object BuiltinsToPyData(const BaseRef &value); | |||||
| py::object VectorToPyData(const Any &value); | |||||
| py::object VectorRefToPyData(const VectorRef &value); | |||||
| py::object ValuePtrToPyData(const ValuePtr &value) { | |||||
| if (value == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "value is null"; | |||||
| } | |||||
| py::object ret; | |||||
| if (value->isa<Int8Imm>()) { | |||||
| MS_LOG(DEBUG) << "int8"; | |||||
| py::int_ v = value->cast<Int8ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int16Imm>()) { | |||||
| MS_LOG(DEBUG) << "int16"; | |||||
| py::int_ v = value->cast<Int16ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int32Imm>()) { | |||||
| MS_LOG(DEBUG) << "int32"; | |||||
| py::int_ v = value->cast<Int32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<Int64Imm>()) { | |||||
| MS_LOG(DEBUG) << "int64"; | |||||
| py::int_ v = value->cast<Int64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt8Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint8"; | |||||
| py::int_ v = value->cast<UInt8ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt16Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint16"; | |||||
| py::int_ v = value->cast<UInt16ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt32Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint32"; | |||||
| py::int_ v = value->cast<UInt32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<UInt64Imm>()) { | |||||
| MS_LOG(DEBUG) << "uint64"; | |||||
| py::int_ v = value->cast<UInt64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<BoolImm>()) { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ v = value->cast<BoolImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<FP64Imm>()) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ v = value->cast<FP64ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<FP32Imm>()) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ v = value->cast<FP32ImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<StringImm>()) { | |||||
| MS_LOG(DEBUG) << "String"; | |||||
| py::str v = value->cast<StringImmPtr>()->value(); | |||||
| ret = v; | |||||
| } else if (value->isa<tensor::Tensor>()) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<tensor::TensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<tensor::MetaTensor>()) { | |||||
| MS_LOG(DEBUG) << "MetaTensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<tensor::MetaTensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<RefKey>()) { | |||||
| MS_LOG(DEBUG) << "RefKey"; | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<RefKeyPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<ValueTuple>()) { | |||||
| MS_LOG(DEBUG) << "tuple"; | |||||
| auto value_tuple = value->cast<ValueTuplePtr>()->value(); | |||||
| py::tuple rets(value_tuple.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_tuple) { | |||||
| rets[i] = ValuePtrToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value->isa<ValueList>()) { | |||||
| MS_LOG(DEBUG) << "list"; | |||||
| auto value_list = value->cast<ValueListPtr>()->value(); | |||||
| py::list rets(value_list.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_list) { | |||||
| rets[i] = ValuePtrToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value->isa<Ellipsis>()) { | |||||
| ret = py::ellipsis(); | |||||
| } else if (value->isa<ValueSlice>()) { | |||||
| auto slice = value->cast<ValueSlicePtr>(); | |||||
| auto start = ValuePtrToPyData(slice->start()); | |||||
| auto end = ValuePtrToPyData(slice->stop()); | |||||
| auto step = ValuePtrToPyData(slice->step()); | |||||
| ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end, | |||||
| step); | |||||
| } else if (value->isa<Type>()) { | |||||
| py::tuple v(1); | |||||
| v[0] = value->cast<TypePtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value->isa<AnyValue>()) { | |||||
| ret = py::none(); | |||||
| } else if (value->isa<None>()) { | |||||
| ret = py::none(); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Unsupported convert value: " << value->ToString() << " to a PyData."; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object AnyToPyData(const Any &value) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "AnyToPyData " << value.GetString(); | |||||
| if (value.is<int>() || value.is<float>() || value.is<double>() || value.is<bool>()) { | |||||
| ret = BuiltinsToPyData(value); | |||||
| } else if (value.is<ValuePtr>()) { | |||||
| MS_LOG(DEBUG) << "ValuePtr"; | |||||
| ValuePtr v = value.cast<ValuePtr>(); | |||||
| ret = ValuePtrToPyData(v); | |||||
| } else if (value.is<tensor::TensorPtr>()) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = value.cast<tensor::TensorPtr>(); | |||||
| ret = v[0]; | |||||
| } else if (value.is<py::object>()) { | |||||
| MS_LOG(DEBUG) << "py obj"; | |||||
| ret = value.cast<py::object>(); | |||||
| } else if (value.is<std::vector<tensor::TensorPtr>>() || value.is<std::vector<Any>>()) { | |||||
| ret = VectorToPyData(value); | |||||
| } else if (value.is<std::list<Any>>()) { | |||||
| MS_LOG(DEBUG) << "list_any"; | |||||
| auto value_list = value.cast<std::list<Any>>(); | |||||
| py::list rets = py::list(); | |||||
| for (auto &v : value_list) { | |||||
| rets.append(AnyToPyData(v)); | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value.is<std::vector<Any>>()) { | |||||
| auto value_list = value.cast<std::vector<Any>>(); | |||||
| py::tuple rets(value_list.size()); | |||||
| for (size_t i = 0; i < value_list.size(); i++) { | |||||
| rets[i] = AnyToPyData(value_list[i]); | |||||
| } | |||||
| ret = rets; | |||||
| } else if (value.is<TypePtr>()) { | |||||
| py::tuple v(1); | |||||
| v[0] = value.cast<TypePtr>(); | |||||
| ret = v[0]; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "value is not support type"; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object BaseRefToPyData(const BaseRef &value) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "BaseRefToPyData " << value.ToString(); | |||||
| if (utils::isa<int>(value) || utils::isa<float>(value) || utils::isa<double>(value) || utils::isa<bool>(value)) { | |||||
| ret = BuiltinsToPyData(value); | |||||
| } else if (utils::isa<ValuePtr>(value)) { | |||||
| MS_LOG(DEBUG) << "ValuePtr"; | |||||
| ValuePtr v = utils::cast<ValuePtr>(value); | |||||
| ret = ValuePtrToPyData(v); | |||||
| } else if (utils::isa<tensor::TensorPtr>(value)) { | |||||
| MS_LOG(DEBUG) << "tensor"; | |||||
| py::tuple v(1); | |||||
| v[0] = utils::cast<tensor::TensorPtr>(value); | |||||
| ret = v[0]; | |||||
| } else if (utils::isa<PyObjectRef>(value)) { | |||||
| MS_LOG(DEBUG) << "py obj"; | |||||
| PyObjectRef py_ref = utils::cast<PyObjectRef>(value); | |||||
| ret = py_ref.object_; | |||||
| } else if (utils::isa<VectorRef>(value)) { | |||||
| auto vec_ref = utils::cast<VectorRef>(value); | |||||
| ret = VectorRefToPyData(vec_ref); | |||||
| } else if (utils::isa<TypePtr>(value)) { | |||||
| py::tuple v(1); | |||||
| v[0] = utils::cast<TypePtr>(value); | |||||
| ret = v[0]; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "value is not support type"; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object BuiltinsToPyData(const Any &value) { | |||||
| if (value.is<int>()) { | |||||
| MS_LOG(DEBUG) << "int"; | |||||
| py::int_ ret = value.cast<int>(); | |||||
| return std::move(ret); | |||||
| } else if (value.is<float>()) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ ret = value.cast<float>(); | |||||
| return std::move(ret); | |||||
| } else if (value.is<double>()) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ ret = value.cast<double>(); | |||||
| return std::move(ret); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ ret = value.cast<bool>(); | |||||
| return std::move(ret); | |||||
| } | |||||
| } | |||||
| py::object BuiltinsToPyData(const BaseRef &value) { | |||||
| if (utils::isa<int>(value)) { | |||||
| MS_LOG(DEBUG) << "int"; | |||||
| py::int_ ret = utils::cast<int>(value); | |||||
| return std::move(ret); | |||||
| } else if (utils::isa<float>(value)) { | |||||
| MS_LOG(DEBUG) << "float"; | |||||
| py::float_ ret = utils::cast<float>(value); | |||||
| return std::move(ret); | |||||
| } else if (utils::isa<double>(value)) { | |||||
| MS_LOG(DEBUG) << "double"; | |||||
| py::float_ ret = utils::cast<double>(value); | |||||
| return std::move(ret); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "bool"; | |||||
| py::bool_ ret = utils::cast<bool>(value); | |||||
| return std::move(ret); | |||||
| } | |||||
| } | |||||
| py::object VectorToPyData(const Any &value) { | |||||
| py::object ret; | |||||
| if (value.is<std::vector<tensor::TensorPtr>>()) { | |||||
| MS_LOG(DEBUG) << "vector_tensor"; | |||||
| std::vector<tensor::TensorPtr> outputs; | |||||
| outputs = value.cast<std::vector<tensor::TensorPtr>>(); | |||||
| py::tuple tensor_tuple(outputs.size()); | |||||
| for (std::size_t i = 0; i < outputs.size(); ++i) { | |||||
| tensor_tuple[i] = *outputs[i]; | |||||
| } | |||||
| ret = tensor_tuple; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "vector_any"; | |||||
| auto value_list = value.cast<std::vector<Any>>(); | |||||
| py::tuple any_tuple = py::tuple(value_list.size()); | |||||
| size_t i = 0; | |||||
| for (auto &v : value_list) { | |||||
| any_tuple[i] = AnyToPyData(v); | |||||
| i++; | |||||
| } | |||||
| ret = any_tuple; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| py::object VectorRefToPyData(const VectorRef &value_list) { | |||||
| py::object ret; | |||||
| MS_LOG(DEBUG) << "vector_ref"; | |||||
| size_t value_size = value_list.size(); | |||||
| auto ref_tuple = py::tuple(value_size); | |||||
| for (size_t i = 0; i < value_size; i++) { | |||||
| ref_tuple[i] = BaseRefToPyData(value_list[i]); | |||||
| } | |||||
| ret = ref_tuple; | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, | |||||
| const py::object &min_shape, const py::object &max_shape) { | |||||
| if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) { | |||||
| auto ret_vec = shape_obj.cast<ShapeVector>(); | |||||
| auto ret_dtype = type_obj.cast<TypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(ret_dtype); | |||||
| // if the size of shape list is empty, return an scalar abstract | |||||
| if (ret_vec.empty() && (!ret_dtype->isa<TensorType>())) { | |||||
| abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype); | |||||
| return abs_scalar; | |||||
| } | |||||
| AbstractBasePtr tensor = nullptr; | |||||
| ShapeVector min_shape_vec; | |||||
| ShapeVector max_shape_vec; | |||||
| if (!min_shape.is_none()) { | |||||
| min_shape_vec = min_shape.cast<ShapeVector>(); | |||||
| } | |||||
| if (!max_shape.is_none()) { | |||||
| max_shape_vec = max_shape.cast<ShapeVector>(); | |||||
| } | |||||
| auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec); | |||||
| if (ret_dtype->isa<TensorType>()) { | |||||
| auto tensor_type = type_obj.cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element()); | |||||
| tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape); | |||||
| } else { | |||||
| auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype); | |||||
| tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape); | |||||
| } | |||||
| return tensor; | |||||
| } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) { | |||||
| py::tuple shape_tuple = shape_obj.cast<py::tuple>(); | |||||
| py::tuple typeid_tuple = type_obj.cast<py::tuple>(); | |||||
| AbstractBasePtrList ptr_list; | |||||
| for (size_t it = 0; it < shape_tuple.size(); ++it) { | |||||
| auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]); | |||||
| ptr_list.push_back(tensor_it); | |||||
| } | |||||
| auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list); | |||||
| return tuple; | |||||
| } else if (shape_obj.is_none() && type_obj.is_none()) { | |||||
| // AbstractNone indicates there is no output for this CNode node. | |||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | |||||
| return abstract_none; | |||||
| } else { | |||||
| // When sparse enabled, the undetermined might be raised and eliminated in opt passes | |||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse = context->enable_sparse(); | |||||
| if (enable_sparse) { | |||||
| return std::make_shared<abstract::AbstractUndetermined>(); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Python evaluator return invalid shape or type. " << (std::string)py::str(type_obj); | |||||
| } | |||||
| } | |||||
| bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, | |||||
| const std::shared_ptr<py::object> &ret_val) { | |||||
| if (output->isa<ValueNode>()) { | |||||
| MS_LOG(INFO) << "Graph's output is a constant. No need to execute."; | |||||
| ValuePtr value = GetValueNode(output); | |||||
| *ret_val = ValuePtrToPyData(value); | |||||
| return true; | |||||
| } | |||||
| // Adapter will transform values in __init__() and construct() to parameters, this could cause | |||||
| // inputs (a.k.a args in current function) size less than parameters'. | |||||
| if (output->isa<Parameter>()) { | |||||
| MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute."; | |||||
| // Find the right parameter as ret_val. | |||||
| auto func_graph = output->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto params = func_graph->parameters(); | |||||
| if ((args.size() + func_graph->hyper_param_count()) != params.size()) { | |||||
| MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count() | |||||
| << " not equal to graph input size " << params.size() << ", let graph to be executed."; | |||||
| } | |||||
| auto it = std::find(params.begin(), params.end(), output); | |||||
| if (it == params.end()) { | |||||
| MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters"; | |||||
| } | |||||
| size_t index = it - params.cbegin(); | |||||
| if (index >= args.size() + func_graph->hyper_param_count()) { | |||||
| MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() | |||||
| << " add Parameter count " << func_graph->hyper_param_count() << "."; | |||||
| } | |||||
| if (index < args.size()) { | |||||
| *ret_val = args[index]; | |||||
| } else { | |||||
| auto param = dyn_cast<Parameter>(params[index]); | |||||
| MS_EXCEPTION_IF_NULL(param); | |||||
| if (!param->has_default()) { | |||||
| MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")"; | |||||
| } | |||||
| auto tensor = param->default_param(); | |||||
| *ret_val = py::cast(tensor); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_PY_H_ | |||||
| #define MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_PY_H_ | |||||
| #include <memory> | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "utils/convert_utils_base.h" | |||||
| #include "utils/any.h" | |||||
| #include "utils/base_ref_extends.h" | |||||
| #include "ir/anf.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| py::object AnyToPyData(const Any &value); | |||||
| py::object BaseRefToPyData(const BaseRef &value); | |||||
| py::object ValuePtrToPyData(const ValuePtr &value); | |||||
| bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, | |||||
| const std::shared_ptr<py::object> &ret_val); | |||||
| AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, | |||||
| const py::object &min_shape = py::none(), | |||||
| const py::object &max_shape = py::none()); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_PY_H_ | |||||
| @@ -22,6 +22,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "utils/base_ref_extends.h" | #include "utils/base_ref_extends.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| py::function GetBpropFunctionByObj(py::object obj) { | py::function GetBpropFunctionByObj(py::object obj) { | ||||
| @@ -24,7 +24,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "utils/base_ref_extends.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/optimizer/cse.h" | |||||
| #include "frontend/optimizer/cse_pass.h" | |||||
| #include "frontend/optimizer/optimizer.h" | #include "frontend/optimizer/optimizer.h" | ||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| @@ -53,7 +53,7 @@ TEST_F(TestOptOptimizer, test_step_opt) { | |||||
| irpass.inline_, | irpass.inline_, | ||||
| }}, | }}, | ||||
| {"grad", {irpass.expand_jprim_}}, | {"grad", {irpass.expand_jprim_}}, | ||||
| {"cse", OptPassConfig(CSE(false))}}, | |||||
| {"cse", OptPassConfig(CSEPass(false))}}, | |||||
| true); | true); | ||||
| EXPECT_TRUE(optimizer.get() != nullptr); | EXPECT_TRUE(optimizer.get() != nullptr); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "pipeline/jit/static_analysis/static_analysis.h" | #include "pipeline/jit/static_analysis/static_analysis.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -13,10 +13,12 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "debug/anf_ir_utils.h" | |||||
| #include "debug/dump_proto.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { return; } | |||||
| std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; } | std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return ""; } | ||||
| std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } | std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; } | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "vm/transform.h" | #include "vm/transform.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/convert_utils_py.h" | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||