From 5f9d2759eefd134dc230437081237ed7501bfb79 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Sat, 20 Jun 2020 10:39:23 +0800 Subject: [PATCH] fix bug of hccl kernel info and change cast's kernel info --- .../ccsrc/kernel/hccl/hccl_kernel_metadata.cc | 42 ++++++++++--------- .../ascend/format_type/merge_cast_to_op.cc | 28 +++++++++++-- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc index f0a0dda258..601d5cf1ea 100755 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc @@ -23,6 +23,8 @@ namespace mindspore { namespace kernel { void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + const std::vector kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, + kNumberTypeFloat32, kNumberTypeInt16}; MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_node); std::string op_name = AnfAlgo::GetCNodeName(kernel_node); @@ -30,27 +32,27 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector inputs_format{}; - std::vector inputs_type{}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); - } - - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); + for (const auto &type : kHcclSupportTypes) { + std::vector inputs_format{}; + std::vector inputs_type{}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); + inputs_type.push_back(type); + } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); + outputs_type.push_back(type); + } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetKernelType(HCCL_KERNEL); + kernel_info_list->push_back(builder.Build()); } - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetKernelType(HCCL_KERNEL); - kernel_info_list->push_back(builder.Build()); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc index 8bb58c18a5..4377eddf32 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc @@ -120,6 +120,24 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptrGetOutputFormat(index); } +void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { + using Shape = std::vector; + auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); + auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); + std::vector shapes; + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + if (cast_index == index) { + shapes.emplace_back(cast_shape); + types.emplace_back(cast_dtype); + continue; + } + shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); + types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); +} + AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(kernel_query); @@ -151,9 +169,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" << (*alternative_kernel_info)->ToString(); AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); + ChangeNodeInferInfo(next_cnode, node, cast_index); if (node->inputs().size() < kCastInputNum) { - auto op_name = AnfAlgo::GetCNodeName(node); - MS_LOG(EXCEPTION) << "op[" << op_name << "] has wrong input num:"; + MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; } return node->input(1); } @@ -223,7 +241,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" << (*kernel_info_it)->ToString(); AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); - + ChangeNodeInferInfo(prior_op, cur_node, output_idx); + if (!single_output) { + MS_EXCEPTION_IF_NULL(x_node); + ChangeNodeInferInfo(x_node->cast(), cur_node, 0); + } auto prior_name = AnfAlgo::GetCNodeName(prior_op); if (prior_name == kFive2FourOpName) { AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op);