| @@ -95,6 +95,16 @@ constexpr auto kJSocVersion = "socVersion"; | |||
| constexpr auto kSOC_VERSION = "SOC_VERSION"; | |||
| constexpr auto kJIsDynamicShape = "is_dynamic_shape"; | |||
| bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_LOG(INFO) << "Check if need change default format"; | |||
| if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format"); | |||
| return attr == kOpFormat_NCDHW; | |||
| } | |||
| return false; | |||
| } | |||
| bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, | |||
| nlohmann::json *kernel_json) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| @@ -161,10 +171,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||
| bool value, const std::shared_ptr<OpIOInfo> &input_ptr, | |||
| const string &op_input_name, size_t input_i, | |||
| std::vector<nlohmann::json> *input_list) { | |||
| auto def_format = kOpFormat_NCHW; | |||
| auto dtype = GetDeviceInputType(anf_node, real_input_index); | |||
| auto format = GetDeviceInputFormat(anf_node, real_input_index); | |||
| auto shape = GetDeviceInputShape(anf_node, real_input_index); | |||
| auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| def_format = kOpFormat_NCDHW; | |||
| } | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| @@ -172,7 +186,7 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||
| input_desc_json[kJDtype] = dtype; | |||
| input_desc_json[kJName] = op_input_name + std::to_string(input_i); | |||
| input_desc_json[kJOriShape] = ori_shape; | |||
| input_desc_json[kJOriFormat] = kOpFormat_NCHW; | |||
| input_desc_json[kJOriFormat] = def_format; | |||
| input_desc_json[kJShape] = shape; | |||
| input_desc_json[kJFormat] = format; | |||
| input_desc_json[kJValid] = value; | |||
| @@ -379,6 +393,10 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||
| std::vector<nlohmann::json> *output_list) { | |||
| MS_EXCEPTION_IF_NULL(output_idx); | |||
| MS_EXCEPTION_IF_NULL(output_list); | |||
| auto def_format = kOpFormat_NCHW; | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| def_format = kOpFormat_NCDHW; | |||
| } | |||
| for (size_t i = 0; i < output_obj_num; i++) { | |||
| auto dtype = GetDeviceOutputType(anf_node, *output_idx); | |||
| auto format = GetDeviceOutputFormat(anf_node, *output_idx); | |||
| @@ -397,7 +415,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||
| output_obj[kJShape] = shape; | |||
| output_obj[kJFormat] = format; | |||
| output_obj[kJOriShape] = ori_shape; | |||
| output_obj[kJOriFormat] = kOpFormat_NCHW; | |||
| output_obj[kJOriFormat] = def_format; | |||
| output_obj[kJName] = output_ptr->name(); | |||
| output_obj[kJValid] = true; | |||
| output_obj[kJParamType] = output_ptr->param_type(); | |||
| @@ -580,6 +598,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod | |||
| format = kOpFormat_NCHW; | |||
| } | |||
| } | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| return format; | |||
| } | |||
| @@ -619,6 +640,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no | |||
| format = kOpFormat_NCHW; | |||
| } | |||
| } | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| return format; | |||
| } | |||
| @@ -818,6 +842,10 @@ void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) { | |||
| void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx, | |||
| size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) { | |||
| GenPreDescJson(output_desc); | |||
| auto def_format = kOpFormat_NCHW; | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| def_format = kOpFormat_NCDHW; | |||
| } | |||
| // data_type | |||
| auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); | |||
| (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id); | |||
| @@ -828,7 +856,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_ | |||
| } | |||
| (*output_desc)[kJName] = output_desc_name; | |||
| // ori_format | |||
| (*output_desc)[kJOriFormat] = kOpFormat_NCHW; | |||
| (*output_desc)[kJOriFormat] = def_format; | |||
| // ori_shape | |||
| auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); | |||
| if (ori_shape.empty()) { | |||
| @@ -248,13 +248,57 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support | |||
| bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| return false; | |||
| if (IsSameShape()) { | |||
| if (!HasScalarInput()) { | |||
| AssignSupportFormat(kOpFormat_NDC1HWC0, support_format); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| SupportFormatItem input_support_format; | |||
| SupportFormatItem output_support_format; | |||
| if (HasScalarInput()) { | |||
| for (const auto &shape : input_shapes_) { | |||
| if (IsScalarShape(shape)) { | |||
| input_support_format.emplace_back(kOpFormat_NCDHW); | |||
| } else if (!Is5DShape(shape)) { | |||
| return false; | |||
| } else if (shape[kChannelC] % kAlignmented16 != 0) { | |||
| return false; | |||
| } else { | |||
| input_support_format.emplace_back(kOpFormat_NDC1HWC0); | |||
| } | |||
| } | |||
| } else { | |||
| for (const auto &shape : input_shapes_) { | |||
| if (!Is5DShape(shape)) { | |||
| return false; | |||
| } | |||
| } | |||
| auto shape_tmp = input_shapes_[0]; | |||
| auto broadcast_c_axis = std::any_of( | |||
| input_shapes_.begin(), input_shapes_.end(), | |||
| [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); }); | |||
| if (broadcast_c_axis) { | |||
| MS_LOG(INFO) << "This node broadcast c channel."; | |||
| return false; | |||
| } | |||
| input_support_format.assign(input_num_, kOpFormat_NDC1HWC0); | |||
| } | |||
| GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format); | |||
| support_format->input_format.emplace_back(input_support_format); | |||
| support_format->output_format.emplace_back(output_support_format); | |||
| return true; | |||
| } | |||
| bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const { | |||
| return shape.size() == kShape4dDims; | |||
| } | |||
| bool TbeKernelBroadCastSelecter::Is5DShape(const std::vector<size_t> &shape) const { | |||
| return shape.size() == kShape5dDims; | |||
| } | |||
| bool TbeKernelBroadCastSelecter::IsSameShape() const { | |||
| auto shape = input_shapes_.begin(); | |||
| for (const auto &item : input_shapes_) { | |||
| @@ -40,6 +40,7 @@ class TbeKernelBroadCastSelecter { | |||
| bool IsSameShape() const; | |||
| void PadScalarShape(std::vector<size_t> *shape) const; | |||
| bool Is4DShape(const std::vector<size_t> &shape) const; | |||
| bool Is5DShape(const std::vector<size_t> &shape) const; | |||
| bool IsScalarShape(const std::vector<size_t> &shape) const; | |||
| bool HasScalarInput() const; | |||
| void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; | |||
| @@ -72,8 +72,18 @@ bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) | |||
| bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { | |||
| MS_EXCEPTION_IF_NULL(support_format); | |||
| // like to 5HD | |||
| return false; | |||
| if (!Is5DShape(input_shape_)) { | |||
| return false; | |||
| } | |||
| if (!keep_dims_ || axis_.empty()) { | |||
| return false; | |||
| } | |||
| auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); | |||
| if (reduce_c_axis) { | |||
| return false; | |||
| } | |||
| AssignSupportFormat(kOpFormat_NDC1HWC0, support_format); | |||
| return true; | |||
| } | |||
| bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { | |||
| @@ -142,6 +152,8 @@ void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_for | |||
| bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; } | |||
| bool TbeKernelReduceSelecter::Is5DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape5dDims; } | |||
| void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->empty()) { | |||
| @@ -39,6 +39,7 @@ class TbeKernelReduceSelecter { | |||
| void GetReduceAttrKeepDim(); | |||
| void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; | |||
| bool Is4DShape(const std::vector<size_t> &shape) const; | |||
| bool Is5DShape(const std::vector<size_t> &shape) const; | |||
| void PadScalarShape(std::vector<size_t> *shape) const; | |||
| CNodePtr cnode_ptr_; | |||
| std::vector<size_t> input_shape_{}; | |||
| @@ -187,6 +187,9 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { | |||
| if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; | |||
| } | |||
| if (!broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format)) { | |||
| MS_LOG(INFO) << "Node(" << node_name_ << ") does not support NDC1HWC0."; | |||
| } | |||
| PrintSupportedFormat(support_format); | |||
| OpInfo op_info_new; | |||
| CreateNewOpInfo(op_info, support_format, &op_info_new); | |||
| @@ -281,10 +284,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const | |||
| return true; | |||
| } | |||
| // not support format: | |||
| // 1 NDHWC with shape size != 5 | |||
| // 3 !NDHWC with shape size > 4 | |||
| if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || | |||
| (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { | |||
| // 1 NCDHW with shape size != 5 | |||
| if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) { | |||
| MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); | |||
| return false; | |||
| } | |||
| @@ -32,7 +32,7 @@ namespace mindspore { | |||
| namespace opt { | |||
| using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; | |||
| namespace { | |||
| const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; | |||
| const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; | |||
| AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, | |||
| const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { | |||
| std::vector<AnfNodePtr> trans_inputs; | |||
| @@ -70,9 +70,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| CNodePtr trans_data = nullptr; | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // Init | |||
| std::string default_format = kOpFormat_DEFAULT; | |||
| if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format"); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| default_format = kOpFormat_NCDHW; | |||
| } | |||
| } | |||
| AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node; | |||
| std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index); | |||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT; | |||
| std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); | |||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; | |||
| std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) | |||
| : AnfAlgo::GetOutputReshapeType(node, insert_index); | |||
| auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) | |||
| @@ -369,6 +369,26 @@ void KernelGraph::CheckLoop() { | |||
| } | |||
| } | |||
| void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_builder); | |||
| kernel_build_info_builder->SetOutputsFormat({format}); | |||
| kernel_build_info_builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); | |||
| } | |||
| void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) { | |||
| auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i); | |||
| MS_EXCEPTION_IF_NULL(in_node); | |||
| if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) { | |||
| ReSetParameterValueNodeFormatAndType(in_node, format); | |||
| } | |||
| } | |||
| } | |||
| CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| auto cnode = FuncGraph::NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -378,6 +398,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | |||
| } | |||
| SetKernelInfoForNode(cnode); | |||
| if (AnfAlgo::HasNodeAttr("io_format", cnode)) { | |||
| auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format"); | |||
| if (attr == kOpFormat_NCDHW) { | |||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||
| } | |||
| } | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| return cnode; | |||
| } | |||
| @@ -273,6 +273,7 @@ class KernelGraph : public FuncGraph { | |||
| // remove value node form graph | |||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | |||
| void SetKernelInfoForNode(const AnfNodePtr &node) const; | |||
| void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; | |||
| AnfNodePtr MakeValueNode(const AnfNodePtr &node); | |||
| void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||
| std::unordered_set<AnfNodePtr> *visited_nodes); | |||
| @@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| // NCDHW | |||
| if (shape.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||
| } | |||
| std::vector<size_t> device_shape; | |||
| const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||
| const size_t C0 = kCubeSize; | |||
| device_shape.push_back(shape[0]); | |||
| device_shape.push_back(shape[2]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[3]); | |||
| device_shape.push_back(shape[4]); | |||
| device_shape.push_back(C0); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { | |||
| // NCDHW -> Frac_Z_3D | |||
| if (shape.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||
| } | |||
| std::vector<size_t> device_shape; | |||
| const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||
| const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize; | |||
| device_shape.push_back(shape[2]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[3]); | |||
| device_shape.push_back(shape[4]); | |||
| device_shape.push_back(N1); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) { | |||
| std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) { | |||
| if (shape.size() < kNdhwc) { | |||
| MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | |||
| } | |||
| @@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}, | |||
| {kOpFormat_NDHWC, NdhwcDeviceShape}}; | |||
| {kOpFormat_NCDHW, NcdhwDeviceShape}, | |||
| {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape}, | |||
| {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}}; | |||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | |||
| return shape; | |||
| @@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| if (shape.size() != kNchwDims) { | |||
| if (shape.size() != kNchwDims && shape.size() != 5) { | |||
| MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | |||
| temp_shape = PaddingShapeTo4dByDefault(shape); | |||
| } | |||
| @@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) { | |||
| const std::map<std::string, FormatTransfer> format_trans_map{ | |||
| {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz}, | |||
| {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | |||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}}; | |||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | |||
| {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}}; | |||
| MS_LOG(DEBUG) << "Start trans format."; | |||
| if (abstract::TypeIdSize(args.src_data_type) < 1) { | |||
| MS_LOG(ERROR) << "Invalid datatype.."; | |||
| @@ -514,11 +553,11 @@ bool TransFormat(const FormatArgs &args, void *result) { | |||
| bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { | |||
| using FormatTransfer = std::function<bool(const FormatArgs &, void *)>; | |||
| const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw}, | |||
| {kOpFormat_FRAC_NZ, FracNzToNchw}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}}; | |||
| const std::map<std::string, FormatTransfer> format_trans_map{ | |||
| {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}}; | |||
| MS_LOG(DEBUG) << "Start trans format."; | |||
| if (abstract::TypeIdSize(args.src_data_type) < 1) { | |||
| MS_LOG(ERROR) << "Invalid datatype.."; | |||
| @@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { | |||
| } | |||
| return true; | |||
| } | |||
// Repacks device data laid out as NDC1HWC0 back into host NCDHW order.
// The device layout splits channel C into C1 blocks of C0 elements
// (C1 = device_shape[2], C0 = device_shape[5]); element (n, c, d, h, w) on the
// host comes from device index (n, d, c / C0, h, w, c % C0).
// Returns false (with an error log) on malformed shape, dtype, or size.
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  // Host side must be the 5-D NCDHW shape.
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // Sanity check: declared device buffer size must match shape * element size.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c1 = args.device_shape[2];  // number of C0-sized channel blocks
  auto c0 = args.device_shape[5];  // channel block size (cube size)
  // Precomputed strides for the host (NCDHW) and device (NDC1HWC0) layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  // Walk the host layout in order and gather each element from its device slot.
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * cdhw;
    for (size_t c_i = 0; c_i < c; c_i++) {
      size_t c_head = n_head + c_i * dhw;
      for (size_t d_i = 0; d_i < d; d_i++) {
        size_t d_head = c_head + d_i * hw;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = d_head + h_i * w;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t dst_i = h_head + w_i;
            // Split the host channel index into (block, offset-within-block).
            size_t c1_i = c_i / c0;
            size_t c0_i = c_i % c0;
            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
            SetData(size, false, src_idx, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Repacks host NCDHW data into the device NDC1HWC0 layout.
// Channel C is tiled into c1 = ceil(C / kCubeSize) blocks of c0 = kCubeSize
// elements; device slots whose reconstructed channel index exceeds the real C
// are zero-padded. Returns false (with an error log) on malformed shape,
// dtype, or size.
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);
  // Host side must be the 5-D NCDHW shape.
  if (args.host_shape.size() != 5) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  // Sanity check: declared device buffer size must match shape * element size.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[0];
  auto c = args.host_shape[1];
  auto d = args.host_shape[2];
  auto h = args.host_shape[3];
  auto w = args.host_shape[4];
  auto c0 = kCubeSize;        // channel block size
  auto c1 = DivCeil(c, c0);   // number of channel blocks (last may be partial)
  // Precomputed strides for the host (NCDHW) and device (NDC1HWC0) layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  // Walk the device layout in order and scatter each element from its host slot.
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * dc1hwc0;
    for (size_t d_i = 0; d_i < d; d_i++) {
      size_t d_head = n_head + d_i * c1hwc0;
      for (size_t c1_i = 0; c1_i < c1; c1_i++) {
        size_t c1_head = d_head + c1_i * hwc0;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = c1_head + h_i * wc0;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t w_head = h_head + w_i * c0;
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              size_t dst_i = c0_i + w_head;
              // Reconstruct the host channel index from (block, offset).
              size_t c_i = c0_i + c1_i * c0;
              size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              // Slots past the real channel count are filled with zeros.
              auto pad_zero = c_i >= c;
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
| } // namespace trans | |||
| } // namespace mindspore | |||
| @@ -66,6 +66,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result); | |||
| bool NchwToFracZc04(const FormatArgs &args, void *result); | |||
| bool NchwToNc1hwc04(const FormatArgs &args, void *result); | |||
| bool NchwToC1hwncoc0(const FormatArgs &args, void *result); | |||
| bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result); | |||
| // device to host | |||
| bool ToNchw(const FormatArgs &args, void *result); | |||
| bool FracZToNchw(const FormatArgs &args, void *result); | |||
| @@ -73,6 +75,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result); | |||
| bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); | |||
| bool Nc1hwc04ToNchw(const FormatArgs &args, void *result); | |||
| bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); | |||
| bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result); | |||
| } // namespace trans | |||
| } // namespace mindspore | |||
| @@ -292,7 +292,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size | |||
| if (host_shape.empty()) { | |||
| host_shape.emplace_back(1); | |||
| } | |||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { | |||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) { | |||
| if (type_id_ == type) { | |||
| SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); | |||
| sync_ok = true; | |||
| @@ -454,7 +454,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js | |||
| std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const { | |||
| std::vector<size_t> device_shape; | |||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { | |||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { | |||
| device_shape = trans::TransShapeToDevice(*host_shape, format_); | |||
| } else { | |||
| if (host_shape_.empty()) { | |||
| @@ -531,7 +531,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size | |||
| if (host_shape.empty()) { | |||
| host_shape.emplace_back(1); | |||
| } | |||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { | |||
| if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) { | |||
| if (type_id_ == type) { | |||
| SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE); | |||
| sync_ok = true; | |||
| @@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh | |||
| host_shape.emplace_back(1); | |||
| } | |||
| std::vector<size_t> device_shape; | |||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { | |||
| if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { | |||
| device_shape = trans::TransShapeToDevice(host_shape, format_); | |||
| } else { | |||
| host_shape = trans::PaddingShapeTo4d(host_shape); | |||
| @@ -81,6 +81,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||
| string priority_matched_format = kOpFormat_NC1HWC0; | |||
| bool is_init = false; | |||
| bool need_change_nd = false; | |||
| bool is_5d_input = false; | |||
| for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { | |||
| auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); | |||
| if (AnfAlgo::IsFeatureMapInput(cnode, index) && | |||
| @@ -93,14 +94,21 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { | |||
| priority_matched_format = kOpFormat_DEFAULT; | |||
| } | |||
| auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); | |||
| if (input_shape_size == 5) { | |||
| is_5d_input = true; | |||
| } | |||
| need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); | |||
| } | |||
| if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { | |||
| priority_matched_format = kOpFormat_DEFAULT; | |||
| } | |||
| if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) { | |||
| priority_matched_format = kOpFormat_NDC1HWC0; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); | |||
| return priority_matched_format; | |||
| } | |||
| /** | |||
| * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, | |||
| * if equal then next num location | |||
| @@ -157,7 +165,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons | |||
| if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; | |||
| } | |||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { | |||
| if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT || | |||
| kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) { | |||
| (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; | |||
| } | |||
| } | |||
| @@ -376,7 +385,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { | |||
| std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)}; | |||
| if (IsValueNode<tensor::Tensor>(input_kernel_node) && | |||
| AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { | |||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { | |||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || | |||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || | |||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { | |||
| output_format = {selected_kernel_info->GetInputFormat(input_index)}; | |||
| } | |||
| builder->SetOutputsFormat(output_format); | |||
| @@ -386,7 +397,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { | |||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) { | |||
| if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM || | |||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D || | |||
| selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) { | |||
| output_format = {selected_kernel_info->GetInputFormat(input_index)}; | |||
| } | |||
| builder->SetOutputsFormat(output_format); | |||
| @@ -386,11 +386,23 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; | |||
| constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; | |||
| constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04"; | |||
| constexpr auto kOpFormat_NDHWC = "NDHWC"; | |||
| constexpr auto kOpFormat_NCDHW = "NCDHW"; | |||
| constexpr auto kOpFormat_DHWNC = "DHWNC"; | |||
| constexpr auto kOpFormat_DHWCN = "DHWCN"; | |||
| constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0"; | |||
| constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D"; | |||
| constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM"; | |||
| const std::set<std::string> kOpFormatList = { | |||
| kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, | |||
| kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM}; | |||
| const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, | |||
| kOpFormat_ND, kOpFormat_NCHW, | |||
| kOpFormat_NHWC, kOpFormat_HWCN, | |||
| kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, | |||
| kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, | |||
| kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM, | |||
| kOpFormat_NDC1HWC0, kOpFormat_NCDHW, | |||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC, | |||
| kOpFormat_DHWCN}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | |||
| const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||
| kApplyMomentumOpName, | |||
| @@ -427,8 +439,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||
| kSparseApplyProximalAdagradOpName}; | |||
| const std::set<std::string> kHWSpecialFormatSet = { | |||
| kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, | |||
| kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM}; | |||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, | |||
| kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; | |||
| const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | |||