/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ #include #include #include #include #include #include "runtime/hardware/device_context.h" #include "backend/session/session_basic.h" #include "ir/tensor.h" namespace mindspore { using device::DeviceContext; using mindspore::tensor::TensorPtr; using session::InputTensorInfo; using session::KernelWithIndex; using session::OpRunInfo; namespace runtime { class GraphCompiler { public: static GraphCompiler &GetInstance() { static GraphCompiler instance; return instance; } // Set device context which is initialized, the function must be called // before using GraphCompiler and after changing device type or device id. void set_device_context(DeviceContext *device_context); // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, // the detailed implementation of compiling graph is in 'CompileGraphImpl'. GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs); // Construct single op kernel graph and compile the kernel graph in PyNative mode. GraphId CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector *tensors_mask, std::vector *input_tensors, bool *single_op_cache_hit); // Get graph by graph id, if not exist return nullptr, used in Graph mode. KernelGraphPtr Fetch(GraphId graph_id) const; // Get graph by graph info, if not exist return nullptr, used in PyNative mode. KernelGraphPtr Fetch(const GraphInfo &graph_info) const; // The following four methods used in PyNative back propagation to split complete kernel graph to single // op graph, and these methods will be removed to class MindRTBackend after deleting session module. // Cache index for all parameter and output nodes of kernel graph, used to get parameter of single op and // recover output of original complete back propagation kernel graph. void GetParamAndOutputIndex(const KernelGraphPtr &graph, const std::vector &inputs, VectorRef *outputs, std::map *parameter_index, std::map>> *output_indexes); // Get input tensors for single op compile and run, input tensors may convert from value node and parameter in graph // and prev kernel node's output. void GetSingleOpInputTensors(const CNodePtr &kernel, const std::map &op_output, const std::map ¶meter_index, const std::vector &graph_inputs, InputTensorInfo *input_tensor_info); // Get OpRunInfo and GraphInfo for single op compile and run. void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector &input_tensors, OpRunInfo *run_info, GraphInfo *graph_info); // Handle single op output tensor and recover output of original complete kernel graph. void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs, const std::map>> &output_indexes, std::map *op_output_map, VectorRef *outputs, std::vector *runop_output_tensors); // Collect output tensors of back propagation graph for allreduce operators to average gradient, // used in PyNative distributed training mode. void AddGradAddrToBucket(const GraphId &graph_id, const std::vector &grad_tensor); // Clear resource in bucket, such as useless tensors and device memory of all communication operators, // Bucket is used in PyNative distributed training mode, one bucket handles all resource to launch and sync allreduce // operator. void ClearAllBucket(const GraphId &graph_id); private: GraphCompiler() = default; ~GraphCompiler() = default; DISABLE_COPY_AND_ASSIGN(GraphCompiler); // The implementation of compiling graph in Graph Mode, including optimizing graph, // setting operator info, creating kernel and transforming kernel graph to ActorSet. GraphId CompileGraphImpl(const KernelGraphPtr &graph) const; // Create device address for all anf nodes of graph. void CreateDeviceAddress(const KernelGraphPtr &graph) const; DeviceContext *device_context_{nullptr}; // Single op kernel graph cache for PyNative mode. std::unordered_map run_op_graphs_; // The member variable 'session_' will be removed after removing session module. session::SessionPtr session_{nullptr}; }; } // namespace runtime } // namespace mindspore #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_