| @@ -95,6 +95,16 @@ constexpr auto kJSocVersion = "socVersion"; | |||||
| constexpr auto kSOC_VERSION = "SOC_VERSION"; | constexpr auto kSOC_VERSION = "SOC_VERSION"; | ||||
| constexpr auto kJIsDynamicShape = "is_dynamic_shape"; | constexpr auto kJIsDynamicShape = "is_dynamic_shape"; | ||||
| bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_LOG(INFO) << "Check if need change default format"; | |||||
| if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) { | |||||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format"); | |||||
| return attr == kOpFormat_NCDHW; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, | bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, | ||||
| nlohmann::json *kernel_json) { | nlohmann::json *kernel_json) { | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| @@ -161,10 +171,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||||
| bool value, const std::shared_ptr<OpIOInfo> &input_ptr, | bool value, const std::shared_ptr<OpIOInfo> &input_ptr, | ||||
| const string &op_input_name, size_t input_i, | const string &op_input_name, size_t input_i, | ||||
| std::vector<nlohmann::json> *input_list) { | std::vector<nlohmann::json> *input_list) { | ||||
| auto def_format = kOpFormat_NCHW; | |||||
| auto dtype = GetDeviceInputType(anf_node, real_input_index); | auto dtype = GetDeviceInputType(anf_node, real_input_index); | ||||
| auto format = GetDeviceInputFormat(anf_node, real_input_index); | auto format = GetDeviceInputFormat(anf_node, real_input_index); | ||||
| auto shape = GetDeviceInputShape(anf_node, real_input_index); | auto shape = GetDeviceInputShape(anf_node, real_input_index); | ||||
| auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); | auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); | ||||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||||
| def_format = kOpFormat_NCDHW; | |||||
| } | |||||
| if (ori_shape.empty()) { | if (ori_shape.empty()) { | ||||
| ori_shape.emplace_back(1); | ori_shape.emplace_back(1); | ||||
| } | } | ||||
| @@ -172,7 +186,7 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||||
| input_desc_json[kJDtype] = dtype; | input_desc_json[kJDtype] = dtype; | ||||
| input_desc_json[kJName] = op_input_name + std::to_string(input_i); | input_desc_json[kJName] = op_input_name + std::to_string(input_i); | ||||
| input_desc_json[kJOriShape] = ori_shape; | input_desc_json[kJOriShape] = ori_shape; | ||||
| input_desc_json[kJOriFormat] = kOpFormat_NCHW; | |||||
| input_desc_json[kJOriFormat] = def_format; | |||||
| input_desc_json[kJShape] = shape; | input_desc_json[kJShape] = shape; | ||||
| input_desc_json[kJFormat] = format; | input_desc_json[kJFormat] = format; | ||||
| input_desc_json[kJValid] = value; | input_desc_json[kJValid] = value; | ||||
| @@ -379,6 +393,10 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||||
| std::vector<nlohmann::json> *output_list) { | std::vector<nlohmann::json> *output_list) { | ||||
| MS_EXCEPTION_IF_NULL(output_idx); | MS_EXCEPTION_IF_NULL(output_idx); | ||||
| MS_EXCEPTION_IF_NULL(output_list); | MS_EXCEPTION_IF_NULL(output_list); | ||||
| auto def_format = kOpFormat_NCHW; | |||||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||||
| def_format = kOpFormat_NCDHW; | |||||
| } | |||||
| for (size_t i = 0; i < output_obj_num; i++) { | for (size_t i = 0; i < output_obj_num; i++) { | ||||
| auto dtype = GetDeviceOutputType(anf_node, *output_idx); | auto dtype = GetDeviceOutputType(anf_node, *output_idx); | ||||
| auto format = GetDeviceOutputFormat(anf_node, *output_idx); | auto format = GetDeviceOutputFormat(anf_node, *output_idx); | ||||
| @@ -397,7 +415,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||||
| output_obj[kJShape] = shape; | output_obj[kJShape] = shape; | ||||
| output_obj[kJFormat] = format; | output_obj[kJFormat] = format; | ||||
| output_obj[kJOriShape] = ori_shape; | output_obj[kJOriShape] = ori_shape; | ||||
| output_obj[kJOriFormat] = kOpFormat_NCHW; | |||||
| output_obj[kJOriFormat] = def_format; | |||||
| output_obj[kJName] = output_ptr->name(); | output_obj[kJName] = output_ptr->name(); | ||||
| output_obj[kJValid] = true; | output_obj[kJValid] = true; | ||||
| output_obj[kJParamType] = output_ptr->param_type(); | output_obj[kJParamType] = output_ptr->param_type(); | ||||
| @@ -580,6 +598,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod | |||||
| format = kOpFormat_NCHW; | format = kOpFormat_NCHW; | ||||
| } | } | ||||
| } | } | ||||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||||
| format = kOpFormat_NCDHW; | |||||
| } | |||||
| return format; | return format; | ||||
| } | } | ||||
| @@ -619,6 +640,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no | |||||
| format = kOpFormat_NCHW; | format = kOpFormat_NCHW; | ||||
| } | } | ||||
| } | } | ||||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||||
| format = kOpFormat_NCDHW; | |||||
| } | |||||
| return format; | return format; | ||||
| } | } | ||||
| @@ -818,6 +842,10 @@ void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) { | |||||
| void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx, | void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx, | ||||
| size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) { | size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) { | ||||
| GenPreDescJson(output_desc); | GenPreDescJson(output_desc); | ||||
| auto def_format = kOpFormat_NCHW; | |||||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||||
| def_format = kOpFormat_NCDHW; | |||||
| } | |||||
| // data_type | // data_type | ||||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); | auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); | ||||
| (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id); | (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id); | ||||
| @@ -828,7 +856,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_ | |||||
| } | } | ||||
| (*output_desc)[kJName] = output_desc_name; | (*output_desc)[kJName] = output_desc_name; | ||||
| // ori_format | // ori_format | ||||
| (*output_desc)[kJOriFormat] = kOpFormat_NCHW; | |||||
| (*output_desc)[kJOriFormat] = def_format; | |||||
| // ori_shape | // ori_shape | ||||
| auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); | auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); | ||||
| if (ori_shape.empty()) { | if (ori_shape.empty()) { | ||||
| @@ -248,13 +248,57 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support | |||||
// Decides whether this broadcast-style op can use the tiled 3D device format
// NDC1HWC0 and, if so, appends one input-format row and one output-format row
// to `support_format`. Rules (mirroring the 4-D NC1HWC0 selector):
//  * All inputs equal-shaped and non-scalar -> every input uses NDC1HWC0.
//  * Scalar inputs mixed with 5-D inputs: scalars keep NCDHW; each 5-D input
//    must have its C dim aligned to kAlignmented16 so tiling needs no padding.
//  * Differing shapes without scalars: every input must be 5-D and the C axis
//    must not be broadcast (C is split into C1/C0 blocks on device).
// Returns true when the format rows were appended.
bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (IsSameShape()) {
    if (!HasScalarInput()) {
      // All shapes identical and non-scalar: trivially supported.
      AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
      return true;
    }
    return false;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (HasScalarInput()) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        // Scalars stay in the plain 3D host layout.
        input_support_format.emplace_back(kOpFormat_NCDHW);
      } else if (!Is5DShape(shape)) {
        return false;
      } else if (shape[kChannelC] % kAlignmented16 != 0) {
        // C must tile exactly into C0-wide blocks when mixed with scalars.
        return false;
      } else {
        input_support_format.emplace_back(kOpFormat_NDC1HWC0);
      }
    }
  } else {
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return false;
      }
    }
    // Broadcasting along C is not representable once C is tiled into C1/C0.
    auto shape_tmp = input_shapes_[0];
    auto broadcast_c_axis = std::any_of(
      input_shapes_.begin(), input_shapes_.end(),
      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
    if (broadcast_c_axis) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return false;
    }
    input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
  support_format->input_format.emplace_back(input_support_format);
  support_format->output_format.emplace_back(output_support_format);
  return true;
}
| bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const { | bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const { | ||||
| return shape.size() == kShape4dDims; | return shape.size() == kShape4dDims; | ||||
| } | } | ||||
| bool TbeKernelBroadCastSelecter::Is5DShape(const std::vector<size_t> &shape) const { | |||||
| return shape.size() == kShape5dDims; | |||||
| } | |||||
| bool TbeKernelBroadCastSelecter::IsSameShape() const { | bool TbeKernelBroadCastSelecter::IsSameShape() const { | ||||
| auto shape = input_shapes_.begin(); | auto shape = input_shapes_.begin(); | ||||
| for (const auto &item : input_shapes_) { | for (const auto &item : input_shapes_) { | ||||
| @@ -40,6 +40,7 @@ class TbeKernelBroadCastSelecter { | |||||
| bool IsSameShape() const; | bool IsSameShape() const; | ||||
| void PadScalarShape(std::vector<size_t> *shape) const; | void PadScalarShape(std::vector<size_t> *shape) const; | ||||
| bool Is4DShape(const std::vector<size_t> &shape) const; | bool Is4DShape(const std::vector<size_t> &shape) const; | ||||
| bool Is5DShape(const std::vector<size_t> &shape) const; | |||||
| bool IsScalarShape(const std::vector<size_t> &shape) const; | bool IsScalarShape(const std::vector<size_t> &shape) const; | ||||
| bool HasScalarInput() const; | bool HasScalarInput() const; | ||||
| void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; | void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; | ||||
| @@ -72,8 +72,18 @@ bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) | |||||
| bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { | bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { | ||||
| MS_EXCEPTION_IF_NULL(support_format); | MS_EXCEPTION_IF_NULL(support_format); | ||||
| // like to 5HD | |||||
| return false; | |||||
| if (!Is5DShape(input_shape_)) { | |||||
| return false; | |||||
| } | |||||
| if (!keep_dims_ || axis_.empty()) { | |||||
| return false; | |||||
| } | |||||
| auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); | |||||
| if (reduce_c_axis) { | |||||
| return false; | |||||
| } | |||||
| AssignSupportFormat(kOpFormat_NDC1HWC0, support_format); | |||||
| return true; | |||||
| } | } | ||||
| bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { | bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { | ||||
| @@ -142,6 +152,8 @@ void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_for | |||||
| bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; } | bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; } | ||||
| bool TbeKernelReduceSelecter::Is5DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape5dDims; } | |||||
| void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const { | void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const { | ||||
| MS_EXCEPTION_IF_NULL(shape); | MS_EXCEPTION_IF_NULL(shape); | ||||
| if (shape->empty()) { | if (shape->empty()) { | ||||
| @@ -39,6 +39,7 @@ class TbeKernelReduceSelecter { | |||||
| void GetReduceAttrKeepDim(); | void GetReduceAttrKeepDim(); | ||||
| void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; | void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; | ||||
| bool Is4DShape(const std::vector<size_t> &shape) const; | bool Is4DShape(const std::vector<size_t> &shape) const; | ||||
| bool Is5DShape(const std::vector<size_t> &shape) const; | |||||
| void PadScalarShape(std::vector<size_t> *shape) const; | void PadScalarShape(std::vector<size_t> *shape) const; | ||||
| CNodePtr cnode_ptr_; | CNodePtr cnode_ptr_; | ||||
| std::vector<size_t> input_shape_{}; | std::vector<size_t> input_shape_{}; | ||||
| @@ -187,6 +187,9 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { | |||||
| if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | ||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | ||||
| } | } | ||||
| if (!broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format)) { | |||||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support NDC1HWC0."; | |||||
| } | |||||
| PrintSupportedFormat(support_format); | PrintSupportedFormat(support_format); | ||||
| OpInfo op_info_new; | OpInfo op_info_new; | ||||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | CreateNewOpInfo(op_info, support_format, &op_info_new); | ||||
| @@ -281,10 +284,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const | |||||
| return true; | return true; | ||||
| } | } | ||||
| // not support format: | // not support format: | ||||
| // 1 NDHWC with shape size != 5 | |||||
| // 3 !NDHWC with shape size > 4 | |||||
| if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || | |||||
| (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { | |||||
| // 1 NCDHW with shape size != 5 | |||||
| if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) { | |||||
| MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -32,7 +32,7 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | ||||
| namespace { | namespace { | ||||
| const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; | |||||
| const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; | |||||
| AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, | AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, | ||||
| const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { | const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { | ||||
| std::vector<AnfNodePtr> trans_inputs; | std::vector<AnfNodePtr> trans_inputs; | ||||
| @@ -70,9 +70,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||||
| CNodePtr trans_data = nullptr; | CNodePtr trans_data = nullptr; | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| // Init | // Init | ||||
| std::string default_format = kOpFormat_DEFAULT; | |||||
| if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) { | |||||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format"); | |||||
| if (attr == kOpFormat_NCDHW) { | |||||
| default_format = kOpFormat_NCDHW; | |||||
| } | |||||
| } | |||||
| AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node; | AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node; | ||||
| std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index); | |||||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT; | |||||
| std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); | |||||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; | |||||
| std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) | std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) | ||||
| : AnfAlgo::GetOutputReshapeType(node, insert_index); | : AnfAlgo::GetOutputReshapeType(node, insert_index); | ||||
| auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) | auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) | ||||
| @@ -369,6 +369,26 @@ void KernelGraph::CheckLoop() { | |||||
| } | } | ||||
| } | } | ||||
| void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_build_info_builder); | |||||
| kernel_build_info_builder->SetOutputsFormat({format}); | |||||
| kernel_build_info_builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)}); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); | |||||
| } | |||||
| void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) { | |||||
| auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i); | |||||
| MS_EXCEPTION_IF_NULL(in_node); | |||||
| if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) { | |||||
| ReSetParameterValueNodeFormatAndType(in_node, format); | |||||
| } | |||||
| } | |||||
| } | |||||
| CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | ||||
| auto cnode = FuncGraph::NewCNode(inputs); | auto cnode = FuncGraph::NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| @@ -378,6 +398,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | ||||
| } | } | ||||
| SetKernelInfoForNode(cnode); | SetKernelInfoForNode(cnode); | ||||
| if (AnfAlgo::HasNodeAttr("io_format", cnode)) { | |||||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format"); | |||||
| if (attr == kOpFormat_NCDHW) { | |||||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||||
| } | |||||
| } | |||||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | AnfAlgo::SetGraphId(graph_id_, cnode.get()); | ||||
| return cnode; | return cnode; | ||||
| } | } | ||||
| @@ -273,6 +273,7 @@ class KernelGraph : public FuncGraph { | |||||
| // remove value node form graph | // remove value node form graph | ||||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | ||||
| void SetKernelInfoForNode(const AnfNodePtr &node) const; | void SetKernelInfoForNode(const AnfNodePtr &node) const; | ||||
| void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; | |||||
| AnfNodePtr MakeValueNode(const AnfNodePtr &node); | AnfNodePtr MakeValueNode(const AnfNodePtr &node); | ||||
| void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | ||||
| std::unordered_set<AnfNodePtr> *visited_nodes); | std::unordered_set<AnfNodePtr> *visited_nodes); | ||||
| @@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||||
| // NCDHW | |||||
| if (shape.size() != 5) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||||
| } | |||||
| std::vector<size_t> device_shape; | |||||
| const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||||
| const size_t C0 = kCubeSize; | |||||
| device_shape.push_back(shape[0]); | |||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(C1); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[4]); | |||||
| device_shape.push_back(C0); | |||||
| return device_shape; | |||||
| } | |||||
| std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { | |||||
| // NCDHW -> Frac_Z_3D | |||||
| if (shape.size() != 5) { | |||||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||||
| } | |||||
| std::vector<size_t> device_shape; | |||||
| const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||||
| const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; | |||||
| device_shape.push_back(shape[2]); | |||||
| device_shape.push_back(C1); | |||||
| device_shape.push_back(shape[3]); | |||||
| device_shape.push_back(shape[4]); | |||||
| device_shape.push_back(N1); | |||||
| device_shape.push_back(kCubeSize); | |||||
| device_shape.push_back(kCubeSize); | |||||
| return device_shape; | |||||
| } | |||||
| std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | ||||
| if (!CheckDims(shape)) { | if (!CheckDims(shape)) { | ||||
| MS_LOG(EXCEPTION) << "Check dims failed."; | MS_LOG(EXCEPTION) << "Check dims failed."; | ||||
| @@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) { | |||||
| std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) { | |||||
| if (shape.size() < kNdhwc) { | if (shape.size() < kNdhwc) { | ||||
| MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | ||||
| } | } | ||||
| @@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | ||||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | ||||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, | {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, | ||||
| {kOpFormat_NDHWC, NdhwcDeviceShape}}; | |||||
| {kOpFormat_NCDHW, NcdhwDeviceShape}, | |||||
| {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape}, | |||||
| {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}}; | |||||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | ||||
| return shape; | return shape; | ||||
| @@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||||
| device_shape.push_back(kCubeSize); | device_shape.push_back(kCubeSize); | ||||
| return device_shape; | return device_shape; | ||||
| } | } | ||||
| if (shape.size() != kNchwDims) { | |||||
| if (shape.size() != kNchwDims && shape.size() != 5) { | |||||
| MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | ||||
| temp_shape = PaddingShapeTo4dByDefault(shape); | temp_shape = PaddingShapeTo4dByDefault(shape); | ||||
| } | } | ||||
| @@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) { | |||||
| const std::map<std::string, FormatTransfer> format_trans_map{ | const std::map<std::string, FormatTransfer> format_trans_map{ | ||||
| {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, | {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, | ||||
| {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | ||||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}}; | |||||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | |||||
| {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}}; | |||||
| MS_LOG(DEBUG) << "Start trans format."; | MS_LOG(DEBUG) << "Start trans format."; | ||||
| if (abstract::TypeIdSize(args.src_data_type) < 1) { | if (abstract::TypeIdSize(args.src_data_type) < 1) { | ||||
| MS_LOG(ERROR) << "Invalid datatype.."; | MS_LOG(ERROR) << "Invalid datatype.."; | ||||
| @@ -514,11 +553,11 @@ bool TransFormat(const FormatArgs &args, void *result) { | |||||
| bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { | bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { | ||||
| using FormatTransfer = std::function<bool(const FormatArgs &, void *)>; | using FormatTransfer = std::function<bool(const FormatArgs &, void *)>; | ||||
| const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw}, | |||||
| {kOpFormat_FRAC_NZ, FracNzToNchw}, | |||||
| {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, | |||||
| {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, | |||||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}}; | |||||
| const std::map<std::string, FormatTransfer> format_trans_map{ | |||||
| {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, | |||||
| {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, | |||||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}}; | |||||
| MS_LOG(DEBUG) << "Start trans format."; | MS_LOG(DEBUG) << "Start trans format."; | ||||
| if (abstract::TypeIdSize(args.src_data_type) < 1) { | if (abstract::TypeIdSize(args.src_data_type) < 1) { | ||||
| MS_LOG(ERROR) << "Invalid datatype.."; | MS_LOG(ERROR) << "Invalid datatype.."; | ||||
| @@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
// Copies device-format NDC1HWC0 data back into a host NCDHW buffer.
// args.host_shape is NCDHW (must be 5-D); args.device_shape is expected to be
// (N, D, C1, H, W, C0). For each host element (n, c, d, h, w) the device
// source index is found by splitting c into (c1, c0) = (c / C0, c % C0).
// Returns false (after logging) on shape-rank, dtype, or buffer-size mismatch.
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // The device buffer must hold exactly device_shape elements of this dtype.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c1 = args.device_shape[2];  // C1: number of C0-wide channel tiles
  auto c0 = args.device_shape[5];  // C0: channel tile width
  // Precomputed strides: host NCDHW strides ...
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  // ... and device NDC1HWC0 strides.
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * cdhw;
    for (size_t c_i = 0; c_i < c; c_i++) {
      size_t c_head = n_head + c_i * dhw;
      for (size_t d_i = 0; d_i < d; d_i++) {
        size_t d_head = c_head + d_i * hw;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = d_head + h_i * w;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t dst_i = h_head + w_i;
            // Split the host channel into tile index / intra-tile offset.
            size_t c1_i = c_i / c0;
            size_t c0_i = c_i % c0;
            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
            SetData(size, false, src_idx, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Copies host NCDHW data into a device-format NDC1HWC0 buffer.
// args.host_shape is NCDHW (must be 5-D); the device layout is
// (N, D, C1, H, W, C0) with C0 = kCubeSize and C1 = ceil(C / C0).
// Channel lanes beyond the real C (the tail of the last C0 tile) are
// zero-filled via the pad_zero flag passed to SetData.
// Returns false (after logging) on shape-rank, dtype, or buffer-size mismatch.
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // The device buffer must hold exactly device_shape elements of this dtype.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c0 = kCubeSize;         // channel tile width
  auto c1 = DivCeil(c, c0);    // number of channel tiles (last may be padded)
  // Precomputed strides: host NCDHW strides ...
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  // ... and device NDC1HWC0 strides.
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * dc1hwc0;
    for (size_t d_i = 0; d_i < d; d_i++) {
      size_t d_head = n_head + d_i * c1hwc0;
      for (size_t c1_i = 0; c1_i < c1; c1_i++) {
        size_t c1_head = d_head + c1_i * hwc0;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = c1_head + h_i * wc0;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t w_head = h_head + w_i * c0;
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              size_t dst_i = c0_i + w_head;
              // Reassemble the real host channel from (tile, lane).
              size_t c_i = c0_i + c1_i * c0;
              size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              // Lanes past the true channel count are written as zero padding.
              auto pad_zero = c_i >= c;
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
| } // namespace trans | } // namespace trans | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -66,6 +66,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result); | |||||
| bool NchwToFracZc04(const FormatArgs &args, void *result); | bool NchwToFracZc04(const FormatArgs &args, void *result); | ||||
| bool NchwToNc1hwc04(const FormatArgs &args, void *result); | bool NchwToNc1hwc04(const FormatArgs &args, void *result); | ||||
| bool NchwToC1hwncoc0(const FormatArgs &args, void *result); | bool NchwToC1hwncoc0(const FormatArgs &args, void *result); | ||||
| bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result); | |||||
| // device to host | // device to host | ||||
| bool ToNchw(const FormatArgs &args, void *result); | bool ToNchw(const FormatArgs &args, void *result); | ||||
| bool FracZToNchw(const FormatArgs &args, void *result); | bool FracZToNchw(const FormatArgs &args, void *result); | ||||
| @@ -73,6 +75,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result); | |||||
| bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); | bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); | ||||
| bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); | bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); | ||||
| bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); | bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); | ||||
| bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result); | |||||
| } // namespace trans | } // namespace trans | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -292,7 +292,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size | |||||
| if (host_shape.empty()) { | if (host_shape.empty()) { | ||||
| host_shape.emplace_back(1); | host_shape.emplace_back(1); | ||||
| } | } | ||||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { | |||||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) { | |||||
| if (type_id_ == type) { | if (type_id_ == type) { | ||||
| SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); | SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); | ||||
| sync_ok = true; | sync_ok = true; | ||||
| @@ -454,7 +454,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js | |||||
| std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const { | std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const { | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { | |||||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { | |||||
| device_shape = trans::TransShapeToDevice(*host_shape, format_); | device_shape = trans::TransShapeToDevice(*host_shape, format_); | ||||
| } else { | } else { | ||||
| if (host_shape_.empty()) { | if (host_shape_.empty()) { | ||||
| @@ -531,7 +531,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size | |||||
| if (host_shape.empty()) { | if (host_shape.empty()) { | ||||
| host_shape.emplace_back(1); | host_shape.emplace_back(1); | ||||
| } | } | ||||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { | |||||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) { | |||||
| if (type_id_ == type) { | if (type_id_ == type) { | ||||
| SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE); | SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE); | ||||
| sync_ok = true; | sync_ok = true; | ||||
| @@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh | |||||
| host_shape.emplace_back(1); | host_shape.emplace_back(1); | ||||
| } | } | ||||
| std::vector<size_t> device_shape; | std::vector<size_t> device_shape; | ||||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { | |||||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { | |||||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | device_shape = trans::TransShapeToDevice(host_shape, format_); | ||||
| } else { | } else { | ||||
| host_shape = trans::PaddingShapeTo4d(host_shape); | host_shape = trans::PaddingShapeTo4d(host_shape); | ||||
| @@ -81,6 +81,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||||
| string priority_matched_format = kOpFormat_NC1HWC0; | string priority_matched_format = kOpFormat_NC1HWC0; | ||||
| bool is_init = false; | bool is_init = false; | ||||
| bool need_change_nd = false; | bool need_change_nd = false; | ||||
| bool is_5d_input = false; | |||||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | ||||
| auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); | auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); | ||||
| if (AnfAlgo::IsFeatureMapInput(cnode, index) && | if (AnfAlgo::IsFeatureMapInput(cnode, index) && | ||||
| @@ -93,14 +94,21 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||||
| priority_matched_format = kOpFormat_DEFAULT; | priority_matched_format = kOpFormat_DEFAULT; | ||||
| } | } | ||||
| auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); | auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); | ||||
| if (input_shape_size == 5) { | |||||
| is_5d_input = true; | |||||
| } | |||||
| need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); | need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); | ||||
| } | } | ||||
| if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { | if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { | ||||
| priority_matched_format = kOpFormat_DEFAULT; | priority_matched_format = kOpFormat_DEFAULT; | ||||
| } | } | ||||
| if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) { | |||||
| priority_matched_format = kOpFormat_NDC1HWC0; | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); | AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); | ||||
| return priority_matched_format; | return priority_matched_format; | ||||
| } | } | ||||
| /** | /** | ||||
| * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, | * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, | ||||
| * if equal then next num location | * if equal then next num location | ||||
| @@ -157,7 +165,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||||
| if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { | if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { | ||||
| (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; | (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; | ||||
| } | } | ||||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { | |||||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT || | |||||
| kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) { | |||||
| (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; | (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; | ||||
| } | } | ||||
| } | } | ||||
| @@ -376,7 +385,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { | |||||
| std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; | std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; | ||||
| if (IsValueNode<tensor::Tensor>(input_kernel_node) && | if (IsValueNode<tensor::Tensor>(input_kernel_node) && | ||||
| AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { | AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { | ||||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { | |||||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || | |||||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || | |||||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { | |||||
| output_format = {selected_kernel_info->GetInputFormat(input_index)}; | output_format = {selected_kernel_info->GetInputFormat(input_index)}; | ||||
| } | } | ||||
| builder->SetOutputsFormat(output_format); | builder->SetOutputsFormat(output_format); | ||||
| @@ -386,7 +397,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | ||||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { | |||||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || | |||||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || | |||||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { | |||||
| output_format = {selected_kernel_info->GetInputFormat(input_index)}; | output_format = {selected_kernel_info->GetInputFormat(input_index)}; | ||||
| } | } | ||||
| builder->SetOutputsFormat(output_format); | builder->SetOutputsFormat(output_format); | ||||
| @@ -386,11 +386,23 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | |||||
// String identifiers for device/host data layouts, including the 5-D
// (depth-aware) formats used by 3-D convolution kernels.
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
constexpr auto kOpFormat_NDHWC = "NDHWC";
constexpr auto kOpFormat_NCDHW = "NCDHW";
constexpr auto kOpFormat_DHWNC = "DHWNC";
constexpr auto kOpFormat_DHWCN = "DHWCN";
constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
| const std::set<std::string> kOpFormatList = { | |||||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM}; | |||||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, | |||||
| kOpFormat_ND, kOpFormat_NCHW, | |||||
| kOpFormat_NHWC, kOpFormat_HWCN, | |||||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, | |||||
| kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, | |||||
| kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM, | |||||
| kOpFormat_NDC1HWC0, kOpFormat_NCDHW, | |||||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC, | |||||
| kOpFormat_DHWCN}; | |||||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | ||||
| const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | ||||
| kApplyMomentumOpName, | kApplyMomentumOpName, | ||||
| @@ -427,8 +439,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||||
| kSparseApplyProximalAdagradOpName}; | kSparseApplyProximalAdagradOpName}; | ||||
| const std::set<std::string> kHWSpecialFormatSet = { | const std::set<std::string> kHWSpecialFormatSet = { | ||||
| kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, | |||||
| kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM}; | |||||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, | |||||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; | |||||
| const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | ||||