| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||||
| #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include "runtime/hardware/device_context.h" | |||||
| #include "backend/session/session_basic.h" | |||||
| namespace mindspore { | |||||
| namespace runtime { | |||||
| class GraphCompiler { | |||||
| public: | |||||
| static GraphCompiler &GetInstance() { | |||||
| static GraphCompiler instance; | |||||
| return instance; | |||||
| } | |||||
| // Set device context which is initialized, the function must be called | |||||
| // before using GraphCompiler and after changing device type or device id. | |||||
| void set_device_context(device::DeviceContext *device_context); | |||||
| // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, | |||||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | |||||
| GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs); | |||||
| // Run a graph and get the output in Graph mode. | |||||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||||
| // Construct single op kernel graph, compile and run the kernel graph in PyNative mode. | |||||
| void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask, | |||||
| VectorRef *outputs); | |||||
| private: | |||||
| GraphCompiler() = default; | |||||
| ~GraphCompiler() = default; | |||||
| DISABLE_COPY_AND_ASSIGN(GraphCompiler); | |||||
| // The implementation of compiling graph in Graph Mode, including optimizing graph, | |||||
| // setting operator info, creating kernel and transforming kernel graph to ActorSet. | |||||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph); | |||||
| device::DeviceContext *device_context_{nullptr}; | |||||
| // Single op kernel graph cache for PyNative mode. | |||||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||||
| // The member variable 'session_' will be removed after removing session module. | |||||
| session::SessionPtr session_{nullptr}; | |||||
| }; | |||||
| } // namespace runtime | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||||