@@ -506,7 +506,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
   if (select_status == kNoMatched) {
     MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
                     << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
-    kernel::AICpuQuery(kernel_node, &kernel_info_list);
+    kernel::AICPUQuery(kernel_node, &kernel_info_list);
     select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
     AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
   }
@@ -71,21 +71,20 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
   FilterInvalidKernelInfo(kernel_node, kernel_info_list);
 }
-void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
+void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_info_list);
   kernel_info_list->clear();
   AicpuMetadataInfo(kernel_node, kernel_info_list);
   FilterInvalidKernelInfo(kernel_node, kernel_info_list);
 }
-bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
+bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(select_kernel_build_info);
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
   auto cnode = kernel_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  AicpuMetadataInfo(cnode, &kernel_info_list);
-  FilterInvalidKernelInfo(cnode, &kernel_info_list);
+  AICPUQuery(cnode, &kernel_info_list);
   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
                        MS_EXCEPTION_IF_NULL(item);
@@ -93,7 +92,7 @@ bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
                      });
 }
-bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
+bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(select_kernel_build_info);
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
@@ -26,9 +26,9 @@
 namespace mindspore {
 namespace kernel {
 void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
-void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
-bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
-bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
+void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
+bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
+bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
@@ -559,6 +559,9 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   if (format == kOpFormat_DEFAULT) {
     return true;
   }
+  if (format == kOpFormat_NDHWC && shape.size() != kShape5dDims) {
+    return false;
+  }
   // if shape size is 0, the shape will be a scalar
   if (shape.empty()) {
     return true;
@@ -574,21 +577,28 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
 bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
-    if (!IsShapeMatchFormat(shape, format)) {
-      return false;
-    }
-    return true;
-  };
+  const size_t kCAxis = 1;
   for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
-    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
+    if (kernel_build_info.GetOutputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
+      if (output_shape.size() != kShape4dDims || output_shape[kCAxis] > 4) {
+        return false;
+      }
+      return false;
+    }
+    if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
       return false;
     }
   }
   for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
-    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
+    if (!IsShapeMatchFormat(input_shape, kernel_build_info.GetInputFormat(index))) {
+      return false;
+    }
+    if (kernel_build_info.GetInputFormat(index) == kOpFormat_FRACTAL_Z_C04) {
+      if (input_shape.size() != kShape4dDims || input_shape[kCAxis] > 4) {
+        return false;
+      }
       return false;
     }
   }
@@ -20,12 +20,12 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include "kernel/oplib/opinfo.h"
 #include "kernel/kernel_build_info.h"
 namespace mindspore {
 namespace kernel {
 void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
-bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info);
 }  // namespace kernel
 }  // namespace mindspore
@@ -32,13 +32,13 @@ namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
 kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-                                                  const AnfNodePtr &node,
-                                                  const kernel::KernelBuildInfo ori_build_info) {
+                                                  const AnfNodePtr &node, const TypeId device_type,
+                                                  const kernel::KernelBuildInfo &ori_build_info) {
   KernelBuildInfoBuilder builder;
   builder.SetInputsFormat({input_format});
   builder.SetOutputsFormat({output_format});
-  builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)});
-  builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)});
+  builder.SetInputsDeviceType({device_type});
+  builder.SetOutputsDeviceType({device_type});
   builder.SetKernelType(ori_build_info.kernel_type());
   builder.SetFusionType(ori_build_info.fusion_type());
   builder.SetProcessor(ori_build_info.processor());
@@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
   std::vector<kernel::Axis> padding_axis;
-  if (AnfAlgo::IsRealKernel(input)) {
-    padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
-  } else {
-    padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
-  }
+  padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
@@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
 AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
-  std::string output_format;
-  std::vector<size_t> origin_shape;
-  if (!AnfAlgo::IsRealKernel(node)) {
-    output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
-    origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
-  } else {
-    output_format = AnfAlgo::GetOutputFormat(node, 0);
-    origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
-  }
+  std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
+  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
   if (output_format == kOpFormat_NC1KHKWHWC0) {
     MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
                       << node->DebugString();
@@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
   AnfNodePtr trans_node = nullptr;
   AnfNodePtr input_node = node;
   AnfNodePtr trans_data = nullptr;
+  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
   MS_EXCEPTION_IF_NULL(node);
   if (origin_format.empty() || dest_format.empty()) {
     MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
@@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
       MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
     }
     auto cnode = node->cast<CNodePtr>();
+    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
   }
@@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
   MS_EXCEPTION_IF_NULL(trans_data);
   MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
   auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
-  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
+  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
   return trans_node;
 }
@@ -39,11 +39,11 @@ class SupportedChecker {
   virtual ~SupportedChecker() = default;
   virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
                                     const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
-    return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info);
+    return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
   }
   virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
                                    const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
-    return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info);
+    return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
   }
 };
 using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
@@ -114,8 +114,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
       auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
       auto cast = trans_road[1];
-      AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
       if (param_format == format && param_dtype != dtype) {
+        AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
         manager->Replace(trans_road[2], final_node);
         manager->Replace(cur_transop, cast);
       }
@@ -292,6 +292,9 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
                       << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                       << node->DebugString() << "]";
   }
+  if (!AnfAlgo::IsRealKernel(node)) {
+    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -311,6 +314,9 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
                       << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
                       << node->DebugString() << "]";
   }
+  if (!IsRealKernel(node)) {
+    return GetPrevNodeOutputFormat(node, input_idx);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -367,8 +373,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
     } else if (b_shp->isa<abstract::NoShape>()) {
       return std::vector<size_t>();
     } else {
-      MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
-                        << base_shape->ToString();
+      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
+                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
     }
   } else if (base_shape->isa<abstract::NoShape>()) {
     return std::vector<size_t>();
@@ -415,6 +421,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
                       << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                       << node->DebugString() << "]";
   }
+  if (!IsRealKernel(node)) {
+    return GetPrevNodeOutputReshapeType(node, input_idx);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -431,6 +440,9 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                       << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
   }
+  if (!IsRealKernel(node)) {
+    return GetPrevNodeOutputReshapeType(node, output_idx);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -488,6 +500,9 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
     MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                       << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
   }
+  if (!IsRealKernel(node)) {
+    return GetPrevNodeOutputDeviceDataType(node, output_idx);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();
@@ -506,6 +521,9 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
     MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                       << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
   }
+  if (!IsRealKernel(node)) {
+    return GetPrevNodeOutputDeviceDataType(node, 0);
+  }
   auto kernel_info = node->kernel_info();
   MS_EXCEPTION_IF_NULL(kernel_info);
   auto build_info = kernel_info->select_kernel_build_info();