| @@ -23,6 +23,8 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) { | |||
| const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, | |||
| kNumberTypeFloat32, kNumberTypeInt16}; | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -30,27 +32,27 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]"; | |||
| return; | |||
| } | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); | |||
| for (const auto &type : kHcclSupportTypes) { | |||
| std::vector<std::string> inputs_format{}; | |||
| std::vector<TypeId> inputs_type{}; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | |||
| inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); | |||
| inputs_type.push_back(type); | |||
| } | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_type; | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { | |||
| outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); | |||
| outputs_type.push_back(type); | |||
| } | |||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetInputsDeviceType(inputs_type); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_type); | |||
| builder.SetKernelType(HCCL_KERNEL); | |||
| kernel_info_list->push_back(builder.Build()); | |||
| } | |||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetInputsDeviceType(inputs_type); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputsDeviceType(outputs_type); | |||
| builder.SetKernelType(HCCL_KERNEL); | |||
| kernel_info_list->push_back(builder.Build()); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -120,6 +120,24 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel | |||
| return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); | |||
| } | |||
| void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { | |||
| using Shape = std::vector<size_t>; | |||
| auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); | |||
| auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); | |||
| std::vector<Shape> shapes; | |||
| std::vector<TypeId> types; | |||
| for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { | |||
| if (cast_index == index) { | |||
| shapes.emplace_back(cast_shape); | |||
| types.emplace_back(cast_dtype); | |||
| continue; | |||
| } | |||
| shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); | |||
| types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); | |||
| } | |||
| AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(kernel_query); | |||
| @@ -151,9 +169,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co | |||
| << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" | |||
| << (*alternative_kernel_info)->ToString(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); | |||
| ChangeNodeInferInfo(next_cnode, node, cast_index); | |||
| if (node->inputs().size() < kCastInputNum) { | |||
| auto op_name = AnfAlgo::GetCNodeName(node); | |||
| MS_LOG(EXCEPTION) << "op[" << op_name << "] has wrong input num:"; | |||
| MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; | |||
| } | |||
| return node->input(1); | |||
| } | |||
| @@ -223,7 +241,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod | |||
| << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" | |||
| << (*kernel_info_it)->ToString(); | |||
| AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); | |||
| ChangeNodeInferInfo(prior_op, cur_node, output_idx); | |||
| if (!single_output) { | |||
| MS_EXCEPTION_IF_NULL(x_node); | |||
| ChangeNodeInferInfo(x_node->cast<CNodePtr>(), cur_node, 0); | |||
| } | |||
| auto prior_name = AnfAlgo::GetCNodeName(prior_op); | |||
| if (prior_name == kFive2FourOpName) { | |||
| AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op); | |||