| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||
| #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "runtime/hardware/device_context.h" | |||
| #include "backend/session/session_basic.h" | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| class GraphCompiler { | |||
| public: | |||
| static GraphCompiler &GetInstance() { | |||
| static GraphCompiler instance; | |||
| return instance; | |||
| } | |||
| // Set device context which is initialized, the function must be called | |||
| // before using GraphCompiler and after changing device type or device id. | |||
| void set_device_context(device::DeviceContext *device_context); | |||
| // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, | |||
| // the detailed implementation of compiling graph is in 'CompileGraphImpl'. | |||
| GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs); | |||
| // Run a graph and get the output in Graph mode. | |||
| void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | |||
| // Construct single op kernel graph, compile and run the kernel graph in PyNative mode. | |||
| void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask, | |||
| VectorRef *outputs); | |||
| private: | |||
| GraphCompiler() = default; | |||
| ~GraphCompiler() = default; | |||
| DISABLE_COPY_AND_ASSIGN(GraphCompiler); | |||
| // The implementation of compiling graph in Graph Mode, including optimizing graph, | |||
| // setting operator info, creating kernel and transforming kernel graph to ActorSet. | |||
| GraphId CompileGraphImpl(const KernelGraphPtr &graph); | |||
| device::DeviceContext *device_context_{nullptr}; | |||
| // Single op kernel graph cache for PyNative mode. | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| // The member variable 'session_' will be removed after removing session module. | |||
| session::SessionPtr session_{nullptr}; | |||
| }; | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ | |||