| @@ -21,7 +21,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void CommonFinalOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void CommonUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||||
| void CommonUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph); // for debug | |||||
| void AddDynamicShapeAttrPass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void AddDynamicShapeAttrPass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void EliminateIllegalDataTypePass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void EliminateIllegalDataTypePass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | void DynamicShapeConvertPass(const std::shared_ptr<session::KernelGraph> &kernel_graph); | ||||
| @@ -47,7 +47,7 @@ class AscendSession : public SessionBasic { | |||||
| static void BatchBuildKernel(const std::vector<std::shared_ptr<SessionTask>> &build_tasks); | static void BatchBuildKernel(const std::vector<std::shared_ptr<SessionTask>> &build_tasks); | ||||
| protected: | |||||
| protected: // load graph to device related | |||||
| void UnifyMindIR(const KernelGraphPtr &graph) override; | void UnifyMindIR(const KernelGraphPtr &graph) override; | ||||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | ||||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | ||||
| @@ -31,7 +31,7 @@ class CPUSession : public SessionBasic { | |||||
| ~CPUSession() override = default; | ~CPUSession() override = default; | ||||
| void Init(uint32_t device_id) override; | void Init(uint32_t device_id) override; | ||||
| protected: | |||||
| protected: // load graph | |||||
| void UnifyMindIR(const KernelGraphPtr &graph) override { SessionBasic::UnifyMindIR(graph); } | void UnifyMindIR(const KernelGraphPtr &graph) override { SessionBasic::UnifyMindIR(graph); } | ||||
| void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *, | void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *, | ||||
| std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, | ||||
| @@ -228,7 +228,7 @@ class BACKEND_EXPORT SessionBasic : public std::enable_shared_from_this<SessionB | |||||
| // When the device address of the node is used as the output of the graph, the device address will be passed | // When the device address of the node is used as the output of the graph, the device address will be passed | ||||
| // to the output tensor, and the output node will recreate a new device address. This third parameter records | // to the output tensor, and the output node will recreate a new device address. This third parameter records | ||||
| // the relationship between the new and old device address. | // the relationship between the new and old device address. | ||||
| virtual void UpdateOutputTensors(const VectorRef *outputs, | |||||
| virtual void UpdateOutputTensors(const VectorRef *outputs, // the output of graph | |||||
| const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, | const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, | ||||
| std::map<DeviceAddressPtr, DeviceAddressPtr> *); | std::map<DeviceAddressPtr, DeviceAddressPtr> *); | ||||
| virtual void UnifyMindIR(const KernelGraphPtr &graph); | virtual void UnifyMindIR(const KernelGraphPtr &graph); | ||||
| @@ -42,7 +42,7 @@ constexpr size_t kDependInputNum = 3; | |||||
| constexpr size_t kDependFirstInputIdx = 1; | constexpr size_t kDependFirstInputIdx = 1; | ||||
| constexpr size_t kTupleGetItemFirstInputIdx = 1; | constexpr size_t kTupleGetItemFirstInputIdx = 1; | ||||
| } // namespace | } // namespace | ||||
| STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) { | |||||
| STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) { // fmk: Ms | |||||
| MS_ASSERT(func_graph != nullptr); | MS_ASSERT(func_graph != nullptr); | ||||
| auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>(); | auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>(); | ||||
| MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr."); | MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr."); | ||||
| @@ -77,7 +77,7 @@ void CodeMSModelBuild(std::ofstream &ofs, const Configurator *config) { | |||||
| " return kMSStatusLiteNotSupport;\n" | " return kMSStatusLiteNotSupport;\n" | ||||
| " }\n"; | " }\n"; | ||||
| ofs << " int ret = RET_OK;\n"; | ofs << " int ret = RET_OK;\n"; | ||||
| if (config->target() != kARM32M) { | |||||
| if (config->target() != kARM32M) { // only support ARM32M | |||||
| ofs << " ret = Init((void*)model_data, data_size);\n"; | ofs << " ret = Init((void*)model_data, data_size);\n"; | ||||
| } | } | ||||
| if (config->support_parallel()) { | if (config->support_parallel()) { | ||||
| @@ -23,7 +23,7 @@ from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.api import _cell_graph_executor | from mindspore.common.api import _cell_graph_executor | ||||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | |||||
| from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model | |||||
| from mindspore.train.checkpoint_pb2 import Checkpoint | from mindspore.train.checkpoint_pb2 import Checkpoint | ||||
| from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy | from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy | ||||
| @@ -42,7 +42,7 @@ def _convert_type(types): | |||||
| Returns: | Returns: | ||||
| list, list of element in dataset. | list, list of element in dataset. | ||||
| """ | """ | ||||
| ms_types = [] | |||||
| ms_types = [] | |||||
| for np_type in types: | for np_type in types: | ||||
| ms_type = pytype_to_dtype(np_type) | ms_type = pytype_to_dtype(np_type) | ||||
| ms_types.append(ms_type) | ms_types.append(ms_type) | ||||