/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H #define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H #include #include #include #include #include #include #include #include "ir/anf.h" #include "ir/dtype.h" #include "ir/base.h" #include "ir/primitive.h" #include "device/device_address.h" #include "kernel/kernel.h" #include "kernel/kernel_build_info.h" #include "operator/ops.h" #include "utils/contract.h" #include "session/kernel_graph.h" namespace mindspore { namespace session { using AnfVisitFuncion = std::function; using KernelWithIndex = std::pair; class AnfRuntimeAlgorithm { public: // get input_anf_node's real kernel by recurse static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, bool visit_nop_node = false, const std::vector &return_types = { prim::kPrimMakeTuple}); static std::vector GetAllOutput(const AnfNodePtr &node, const std::vector &return_types = {}); // get cnode primitive static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); // get cnode primitive static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); // get kernel_name of anf node static std::string GetCNodeName(const AnfNodePtr &node); // get detail info of anf node static std::string GetNodeDebugString(const AnfNodePtr &node); // get attr of anf node template static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { std::string node_debug_log = node->DebugString(); MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); } auto primitive = GetCNodePrimitive(node); MS_EXCEPTION_IF_NULL(primitive); return GetValue(primitive->GetAttr(key)); } static bool IsTupleOutput(const AnfNodePtr &anf); // set attr of anf node static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); // set attr of key from 'from' node to 'to' node static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); // set a new key for attr from 'from' node to 'to' node static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, const AnfNodePtr &to); // set all attrs from 'from' node to 'to' node static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); // check whether a cnode has the specified attr. static bool HasNodeAttr(const std::string &key, const CNodePtr &node); // delete attr of anf node static void EraseNodeAttr(const std::string &key, AnfNodePtr node); // get the num of input real_kernel(which can be build and run in device) static size_t GetInputTensorNum(const AnfNodePtr &node); // get the num of output real_kernel(which can be build and run in device) static size_t GetOutputTensorNum(const AnfNodePtr &node); // get output format select of anf node static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); // get input format select of anf node static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); // get prev node output width output index static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); // get output format from prev node,input_index is the input index of current node related to prev node static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); // get reshape_type of from the output of input node. static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); // get output shapes inferred by ME from input nodes. static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); // get input shapes inferred by ME from input nodes. static std::vector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); // get output shapes which will built and run in device static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); // get input shapes which will built and run in device static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); // Get Input Padding Axis static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); // Get Output Padding Axis static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); // get output data type inferred by ME of anf node static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); // get output original data type from prev node,input_index is the input index of current node related to prev node static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); // get output select data type of anf node static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); // get input select data type of anf node static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); // get output select data type from prev node,input_index is the input index of current node related to prev node static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); // get output device addr of anf_node static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); // get mutable output device addr of anf_node static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); // check whether output addr is exist or not static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); // get address from prev node,input_index is the input index of current node related to prev node static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, bool visit_nop_node = true); static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = true); // set output device addr of anf_node static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); // set workspace device addr of anf_node static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); // get workspace device addr of anf_node static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); // set infer shapes and types of anf node static void SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node); static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); // get op pattern of the node static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); // get KernelBuildType of node ,such as ATT,RT,FWK and so on static KernelType GetKernelType(const AnfNodePtr &node); // get processor type:AICORE,AICPU... static kernel::Processor GetProcessor(const AnfNodePtr &node); // get fusion type:AICORE,AICPU... static kernel::FusionType GetFusionType(const AnfNodePtr &node); // set select kernel_build_info static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); // get select kernel_build_info static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); // get kernelMode static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); // set kernel mod static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too static bool IsRealKernel(const AnfNodePtr &node); // checkout whether the anf node is a real kernel that is a cnode and can run on device static bool IsRealCNodeKernel(const AnfNodePtr &node); // checkout whether the anf node is a graph kernel. static bool IsGraphKernel(const AnfNodePtr &node); // check parameter is weight or data static bool IsParameterWeight(const ParameterPtr &node); // set stream id of kernel,which will be set in stream assign and be used in stream generate static void SetStreamId(uint32_t stream_id, AnfNode *node); // get stream id static uint32_t GetStreamId(const AnfNodePtr &node); // set stream distinction label to distinguish different ops in different streams static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); // get stream distinction label static uint32_t GetStreamDistinctionLabel(const AnfNode *node); // set graph id static void SetGraphId(uint32_t graph_id, AnfNode *node); // get graph id static uint32_t GetGraphId(const AnfNode *node); static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); // charge if the node's output is a feature map output static bool IsFeatureMapOutput(const AnfNodePtr &node); // charge if the node's input is from a feature map output static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node); static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index); static void ReorderExecList(NotNull *> node_list); static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph); // get fix output precision of cnode. static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); // get fix output precision from prev node, input_idx is the input index of current node related to prev node. static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H