Merge pull request !19716 from wangnan39/tbe_build_adapt_dynamic_shapetags/v1.4.0
| @@ -44,7 +44,7 @@ def _initialize(impl_path): | |||
| def _replace_range(args): | |||
| for arg in args: | |||
| if not arg.__contains__('range'): | |||
| if not arg or not arg.__contains__('range'): | |||
| continue | |||
| shape_range = arg["range"] | |||
| for range_item in shape_range: | |||
| @@ -129,6 +129,61 @@ void SetLicInfo(nlohmann::json *op_info_json) { | |||
| (*op_info_json)[kJOpTuneList] = LicManager::GetInstance().GetOpTuneList(); | |||
| (*op_info_json)[kJPassList] = LicManager::GetInstance().GetPassSwitch(); | |||
| } | |||
| std::vector<int64_t> GetOutputShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::vector<int64_t> shape; | |||
| auto output_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_index); | |||
| MS_EXCEPTION_IF_NULL(output_shape); | |||
| if (output_shape->isa<abstract::Shape>()) { | |||
| auto shape_ptr = output_shape->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape_ptr); | |||
| shape = shape_ptr->shape(); | |||
| } | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| return shape; | |||
| } | |||
| std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const kCreaterType creater_type, const AnfNodePtr &anf_node, | |||
| const size_t real_index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::vector<int64_t> shape; | |||
| if (creater_type == OP_SELECT_FORMAT || creater_type == CHECK_SUPPORTED) { | |||
| shape = GetOutputShapeForTbeBuild(anf_node, real_index); | |||
| } else { | |||
| auto format = AnfAlgo::GetOutputFormat(anf_node, real_index); | |||
| shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(anf_node, real_index, format); | |||
| } | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| return shape; | |||
| } | |||
| std::vector<int64_t> GetInputShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_index); | |||
| return GetOutputShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second); | |||
| } | |||
| std::vector<int64_t> GetInputDeviceShapeForTbeBuild(const kCreaterType creater_type, const AnfNodePtr &anf_node, | |||
| const size_t real_index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| std::vector<int64_t> shape; | |||
| session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_index); | |||
| if (creater_type == OP_SELECT_FORMAT || creater_type == CHECK_SUPPORTED) { | |||
| shape = GetOutputShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second); | |||
| } else { | |||
| auto format = AnfAlgo::GetInputFormat(anf_node, real_index); | |||
| shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second, format); | |||
| } | |||
| if (shape.empty()) { | |||
| shape.emplace_back(1); | |||
| } | |||
| return shape; | |||
| } | |||
| } // namespace | |||
| bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, | |||
| nlohmann::json *kernel_json) { | |||
| @@ -232,17 +287,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||
| auto def_format = kOpFormat_NCHW; | |||
| auto dtype = GetDeviceInputType(anf_node, real_input_index); | |||
| auto format = GetDeviceInputFormat(anf_node, real_input_index); | |||
| auto shape = GetDeviceInputShape(anf_node, real_input_index); | |||
| auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); | |||
| auto shape = GetInputDeviceShapeForTbeBuild(creater_type_, anf_node, real_input_index); | |||
| auto ori_shape = GetInputShapeForTbeBuild(anf_node, real_input_index); | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| def_format = kOpFormat_NCDHW; | |||
| } | |||
| if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| nlohmann::json input_desc_json; | |||
| input_desc_json[kJDtype] = dtype; | |||
| input_desc_json[kJName] = op_input_name + std::to_string(input_i); | |||
| @@ -463,17 +515,12 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||
| auto dtype = GetDeviceOutputType(anf_node, *output_idx); | |||
| auto format = GetDeviceOutputFormat(anf_node, *output_idx); | |||
| std::vector<int64_t> shape; | |||
| AnfAlgo::GetRealDynamicShape(GetDeviceOutputShape(anf_node, *output_idx), NOT_NULL(&shape)); | |||
| std::vector<int64_t> shape = GetOutputDeviceShapeForTbeBuild(creater_type_, anf_node, *output_idx); | |||
| std::vector<int64_t> ori_shape = GetOutputShapeForTbeBuild(anf_node, *output_idx); | |||
| std::vector<int64_t> ori_shape; | |||
| AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape)); | |||
| if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| nlohmann::json output_obj; | |||
| output_obj[kJDtype] = dtype; | |||
| output_obj[kJShape] = shape; | |||
| @@ -248,16 +248,41 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| MS_EXCEPTION_IF_NULL(kernel_select); | |||
| CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input}); | |||
| MS_EXCEPTION_IF_NULL(trans_node); | |||
| auto infer_type = AnfAlgo::GetOutputInferDataType(input, 0); | |||
| auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0); | |||
| MS_EXCEPTION_IF_NULL(out_shape_base); | |||
| ShapeVector out_shape; | |||
| ShapeVector out_shape_min; | |||
| ShapeVector out_shape_max; | |||
| bool is_dynamic_shape = false; | |||
| if (out_shape_base->isa<abstract::Shape>()) { | |||
| auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(out_shape_ptr); | |||
| out_shape = out_shape_ptr->shape(); | |||
| if (out_shape_ptr->IsDynamic()) { | |||
| out_shape_min = out_shape_ptr->min_shape(); | |||
| out_shape_max = out_shape_ptr->max_shape(); | |||
| is_dynamic_shape = true; | |||
| } | |||
| } | |||
| if (need_padding) { | |||
| // if need padding we should set the transdata node's shape to the padding shape | |||
| auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); | |||
| AnfAlgo::SetOutputInferTypeAndShape( | |||
| {AnfAlgo::GetOutputInferDataType(input, 0)}, | |||
| {trans::PaddingShape(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputFormat(input, 0), padding_axis)}, | |||
| trans_node.get()); | |||
| abstract::ShapePtr pad_shape_ptr; | |||
| ShapeVector pad_shape = trans::PaddingShape(out_shape, AnfAlgo::GetOutputFormat(input, 0), padding_axis); | |||
| if (is_dynamic_shape) { | |||
| ShapeVector pad_shape_min = trans::PaddingShape(out_shape_min, AnfAlgo::GetOutputFormat(input, 0), padding_axis); | |||
| ShapeVector pad_shape_max = trans::PaddingShape(out_shape_max, AnfAlgo::GetOutputFormat(input, 0), padding_axis); | |||
| pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape, pad_shape_min, pad_shape_max); | |||
| } else { | |||
| pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape); | |||
| } | |||
| AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {pad_shape_ptr}, trans_node.get()); | |||
| } else { | |||
| AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, | |||
| {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); | |||
| AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {out_shape_base}, trans_node.get()); | |||
| } | |||
| // special handle for ut | |||
| if (trans_node->kernel_info() == nullptr) { | |||
| @@ -267,6 +292,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| if (op_name == prim::kPrimTranspose->name()) { | |||
| AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node); | |||
| } | |||
| if (is_dynamic_shape) { | |||
| AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), trans_node); | |||
| AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node); | |||
| AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), trans_node); | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); | |||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node); | |||
| trans_node->set_scope(input->scope()); | |||
| @@ -308,6 +338,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr & | |||
| AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cast); | |||
| AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cast); | |||
| } | |||
| AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(origin_type), cast); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | |||
| AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, cast.get()); | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | |||
| @@ -810,6 +810,27 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo | |||
| return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); | |||
| } | |||
| std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, | |||
| const size_t output_idx, | |||
| const std::string &format) { | |||
| auto output_shape = GetOutputDetailShape(node, output_idx); | |||
| std::vector<int64_t> infer_shape; | |||
| if (output_shape->isa<abstract::Shape>()) { | |||
| auto shape_ptr = output_shape->cast<abstract::ShapePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shape_ptr); | |||
| infer_shape = shape_ptr->shape(); | |||
| } | |||
| if (infer_shape.empty()) { | |||
| return infer_shape; | |||
| } | |||
| // if format is default_format or NC1KHKWHWC0,device shape = original shape | |||
| if (trans::IsNeedPadding(format, infer_shape.size())) { | |||
| infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx)); | |||
| } | |||
| return trans::TransShapeToDevice(infer_shape, format, node, output_idx); | |||
| } | |||
| std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { | |||
| auto format = GetOutputFormat(node, output_idx); | |||
| auto infer_shape = GetOutputInferShape(node, output_idx); | |||
| @@ -154,6 +154,9 @@ class AnfRuntimeAlgorithm { | |||
| static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); | |||
| // get input shapes which will built and run in device | |||
| static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); | |||
| // get output shapes for tbe build | |||
| static std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, const size_t output_idx, | |||
| const std::string &format); | |||
| // Get Input Padding Axis | |||
| static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); | |||
| // Get Output Padding Axis | |||
| @@ -17,9 +17,9 @@ | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include "utils/ms_utils.h" | |||
| #include "abstract/utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" | |||
| #include "runtime/device/convert_tensor_utils.h" | |||
| @@ -27,9 +27,9 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/utils.h" | |||
| using mindspore::abstract::Shape; | |||
| namespace mindspore { | |||
| namespace trans { | |||
| enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw }; | |||
| inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) { | |||
| switch (size) { | |||
| case 1: | |||
| @@ -214,7 +214,12 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const | |||
| } | |||
| namespace { | |||
| bool CheckDims(const std::vector<size_t> &shape) { | |||
| bool HasShapeDynamic(const std::vector<int64_t> &shape_list) { | |||
| return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t shape) { return shape == Shape::SHP_ANY; }); | |||
| } | |||
| template <typename T> | |||
| bool CheckDims(const std::vector<T> &shape) { | |||
| if (shape.size() != kNchwDims) { | |||
| MS_LOG(ERROR) << "Host shape dims should be 4"; | |||
| return false; | |||
| @@ -229,6 +234,13 @@ std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) { | |||
| return shape; | |||
| } | |||
| std::vector<int64_t> NchwDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| return shape; | |||
| } | |||
| std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Ccheck dims failed."; | |||
| @@ -241,6 +253,18 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> NhwcDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Ccheck dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| device_shape.push_back(shape[kN]); | |||
| device_shape.push_back(shape[kH]); | |||
| device_shape.push_back(shape[kW]); | |||
| device_shape.push_back(shape[kC]); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -253,6 +277,18 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> HwchDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| device_shape.push_back(shape[kH]); | |||
| device_shape.push_back(shape[kW]); | |||
| device_shape.push_back(shape[kC]); | |||
| device_shape.push_back(shape[kN]); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -267,6 +303,28 @@ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> FracZDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) { | |||
| device_shape.push_back(Shape::SHP_ANY); | |||
| } else { | |||
| const int64_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||
| device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize); | |||
| } | |||
| if (shape[kN] == Shape::SHP_ANY) { | |||
| device_shape.push_back(Shape::SHP_ANY); | |||
| } else { | |||
| const int64_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize; | |||
| device_shape.push_back(cout16 / kCubeSize); | |||
| } | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -282,6 +340,21 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| const int64_t C1 = (shape[kC] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[kC] + kCubeSize - 1) / kCubeSize; | |||
| const int64_t C0 = kCubeSize; | |||
| device_shape.push_back(shape[kN]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[kH]); | |||
| device_shape.push_back(shape[kW]); | |||
| device_shape.push_back(C0); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| // NCDHW | |||
| if (shape.size() != 5) { | |||
| @@ -299,6 +372,23 @@ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| // NCDHW | |||
| if (shape.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + kCubeSize - 1) / kCubeSize; | |||
| const int64_t C0 = kCubeSize; | |||
| device_shape.push_back(shape[0]); | |||
| device_shape.push_back(shape[2]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[3]); | |||
| device_shape.push_back(shape[4]); | |||
| device_shape.push_back(C0); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { | |||
| // NCDHW -> Frac_Z_3D | |||
| if (shape.size() != 5) { | |||
| @@ -314,6 +404,26 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| // NCDHW -> Frac_Z_3D | |||
| if (shape.size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size(); | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| if (HasShapeDynamic({shape[1], shape[2], shape[3], shape[4]})) { | |||
| device_shape.push_back(Shape::SHP_ANY); | |||
| } else { | |||
| const int64_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; | |||
| device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]); | |||
| } | |||
| const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + kCubeSize - 1) / kCubeSize; | |||
| device_shape.push_back(N1); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -328,6 +438,21 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| shape[kC] == Shape::SHP_ANY ? device_shape.push_back(Shape::SHP_ANY) | |||
| : device_shape.push_back((shape[kC] - 1) / kCubeSize + 1); | |||
| device_shape.push_back(shape[kH]); | |||
| device_shape.push_back(shape[kW]); | |||
| device_shape.push_back(shape[kN]); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -343,6 +468,28 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> FracZc04DeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| const int64_t c0 = 4; | |||
| int64_t first_dim; | |||
| if (HasShapeDynamic({shape[kH], shape[kW]})) { | |||
| first_dim = Shape::SHP_ANY; | |||
| } else { | |||
| first_dim = DivCeil(c0 * shape[kH] * shape[kW], SizeToLong(kCubeSize)); | |||
| } | |||
| auto shape_kN = shape.at(kN); | |||
| int64_t no = (shape_kN == Shape::SHP_ANY) ? Shape::SHP_ANY : DivCeil(shape.at(kN), SizeToLong(kCubeSize)); | |||
| device_shape.push_back(first_dim); | |||
| device_shape.push_back(no); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -358,6 +505,21 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| const int64_t C1 = 1; | |||
| const int64_t C0 = 4; | |||
| device_shape.push_back(shape[kN]); | |||
| device_shape.push_back(C1); | |||
| device_shape.push_back(shape[kH]); | |||
| device_shape.push_back(shape[kW]); | |||
| device_shape.push_back(C0); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) { | |||
| if (shape.size() < kNcdhw) { | |||
| MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | |||
| @@ -365,6 +527,13 @@ std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) { | |||
| return shape; | |||
| } | |||
| std::vector<int64_t> NcdhwDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| if (shape.size() < kNcdhw) { | |||
| MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; | |||
| } | |||
| return shape; | |||
| } | |||
| // change channel-first shape to channel-last shape. | |||
| // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3] | |||
| std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) { | |||
| @@ -380,6 +549,21 @@ std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) { | |||
| return device_shape; | |||
| } | |||
| // change channel-first shape to channel-last shape. | |||
| // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3] | |||
| std::vector<int64_t> ChannelLastDeviceDynamicShape(const std::vector<int64_t> &shape) { | |||
| auto dim = shape.size(); | |||
| std::vector<int64_t> axis; | |||
| axis.resize(dim); | |||
| std::iota(axis.begin() + 1, axis.end(), 2); | |||
| axis[dim - 1] = 1; | |||
| std::vector<int64_t> device_shape; | |||
| std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape), [&shape](int n) { return shape[n]; }); | |||
| return device_shape; | |||
| } | |||
| std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| @@ -399,6 +583,80 @@ std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shape, const int64_t groups = 1) { | |||
| if (!CheckDims(shape)) { | |||
| MS_LOG(EXCEPTION) << "Check dims failed."; | |||
| } | |||
| int64_t c1_dim = Shape::SHP_ANY; | |||
| int64_t g_dim = Shape::SHP_ANY; | |||
| int64_t n1 = Shape::SHP_ANY; | |||
| if (HasShapeDynamic({shape[kC], shape[kN]})) { | |||
| size_t group_size = LongToSize(groups); | |||
| size_t cin_ori_tmp = LongToSize(shape[kC]); | |||
| size_t cout_ori_tmp = LongToSize(shape[kN]) / group_size; | |||
| size_t e_mult = | |||
| std::min(Lcm(Lcm(cin_ori_tmp, kCubeSize) / cin_ori_tmp, Lcm(cout_ori_tmp, kCubeSize) / cout_ori_tmp), group_size); | |||
| int64_t cin_opt = DivCeil(e_mult * cin_ori_tmp, kCubeSize) * kCubeSize; | |||
| c1_dim = cin_opt / kCubeSize; | |||
| g_dim = DivCeil(group_size, e_mult); | |||
| n1 = DivCeil(cout_ori_tmp * e_mult, kCubeSize); | |||
| } | |||
| std::vector<int64_t> device_shape; | |||
| if (HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) { | |||
| device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]); | |||
| } else { | |||
| device_shape.push_back(Shape::SHP_ANY); | |||
| } | |||
| device_shape.push_back(n1); | |||
| device_shape.push_back(kNiSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) { | |||
| std::vector<int64_t> device_shape; | |||
| if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) { | |||
| // For [1] and [1024] shape we can trait it as NZ shape | |||
| return shape; | |||
| } | |||
| if (shape.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size(); | |||
| } else { | |||
| (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape)); | |||
| } | |||
| int64_t h_shape = shape[shape.size() - 2]; | |||
| int64_t w_shape = shape[shape.size() - 1]; | |||
| int64_t h1 = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (h_shape - 1) / kCubeSize + 1; | |||
| int64_t w1 = (w_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (w_shape - 1) / kCubeSize + 1; | |||
| device_shape.push_back(w1); | |||
| device_shape.push_back(h1); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| std::vector<int64_t> TransShapeToFracNZLSTM(const std::vector<int64_t> &shape) { | |||
| std::vector<int64_t> device_shape; | |||
| const int64_t c0 = 4; | |||
| const int64_t h_shape = shape.at(kN); | |||
| const int64_t i_shape = shape.at(kC); | |||
| const int64_t h = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : h_shape / c0; | |||
| int64_t first = Shape::SHP_ANY; | |||
| if (h_shape != Shape::SHP_ANY && i_shape != Shape::SHP_ANY) { | |||
| int64_t i = i_shape - h; | |||
| first = DivCeil(i, SizeToLong(kCubeSize)) + DivCeil(h, SizeToLong(kCubeSize)); | |||
| } | |||
| const int64_t second = (h == Shape::SHP_ANY) ? Shape::SHP_ANY : c0 * DivCeil(h, SizeToLong(kCubeSize)); | |||
| device_shape.push_back(first); | |||
| device_shape.push_back(second); | |||
| device_shape.push_back(kCubeSize); | |||
| device_shape.push_back(kCubeSize); | |||
| return device_shape; | |||
| } | |||
| } // namespace | |||
| int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) { | |||
| @@ -439,20 +697,6 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) { | |||
| return false; | |||
| } | |||
| std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format, | |||
| const std::string &pad_index) { | |||
| std::vector<size_t> host_shape; | |||
| if (k3DFormatSet.find(format) != k3DFormatSet.end()) { | |||
| if (shape.size() >= kNcdhw) { | |||
| return shape; | |||
| } | |||
| host_shape = trans::PaddingShapeTo5d(shape, pad_index); | |||
| } else { | |||
| host_shape = trans::PaddingShapeTo4d(shape, pad_index); | |||
| } | |||
| return host_shape; | |||
| } | |||
| ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| ShapeVector shape; | |||
| @@ -536,90 +780,6 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5 | |||
| } | |||
| } | |||
| std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_str) { | |||
| std::vector<Axis5D> padding_axis; | |||
| StringToAxisVector5D(padding_str, &padding_axis); | |||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||
| return PaddingShapeTo5dDefault(shape); | |||
| } | |||
| std::vector<size_t> shape_5d(kNcdhw, 1); | |||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||
| shape_5d[padding_axis[index]] = shape[index]; | |||
| } | |||
| return shape_5d; | |||
| } | |||
| std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_str) { | |||
| std::vector<Axis> padding_axis; | |||
| StringToAxisVector4D(padding_str, &padding_axis); | |||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||
| return PaddingShapeTo4dDefault(shape); | |||
| } | |||
| std::vector<size_t> shape_4d(kNchwDims, 1); | |||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||
| shape_4d[padding_axis[index]] = shape[index]; | |||
| } | |||
| return shape_4d; | |||
| } | |||
| std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) { | |||
| if (shape.size() >= kNcdhw) { | |||
| return shape; | |||
| } | |||
| std::vector<size_t> shape_5d(kNcdhw, 1); | |||
| switch (shape.size()) { | |||
| case 0: | |||
| return shape_5d; | |||
| case 1: | |||
| shape_5d[1] = shape[0]; | |||
| break; | |||
| case 2: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| break; | |||
| case 3: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| shape_5d[3] = shape[2]; | |||
| break; | |||
| case 4: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| shape_5d[3] = shape[2]; | |||
| shape_5d[4] = shape[3]; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); | |||
| } | |||
| return shape_5d; | |||
| } | |||
| std::vector<size_t> PaddingShapeTo4dDefault(const std::vector<size_t> &shape) { | |||
| std::vector<size_t> shape_4d(kNchwDims, 1); | |||
| switch (shape.size()) { | |||
| case 0: | |||
| return shape_4d; | |||
| case 1: | |||
| shape_4d[kC] = shape[kN]; | |||
| break; | |||
| case 2: | |||
| shape_4d[kC] = shape[kN]; | |||
| shape_4d[kH] = shape[kC]; | |||
| break; | |||
| case 3: | |||
| shape_4d[kC] = shape[kN]; | |||
| shape_4d[kH] = shape[kC]; | |||
| shape_4d[kW] = shape[kH]; | |||
| break; | |||
| case 4: | |||
| std::copy(shape.begin(), shape.end(), shape_4d.begin()); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); | |||
| } | |||
| return shape_4d; | |||
| } | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, | |||
| const int64_t groups) { | |||
| using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>; | |||
| @@ -687,13 +847,47 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s | |||
| return iter->second(temp_shape); | |||
| } | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, | |||
| const AnfNodePtr &node, const size_t index) { | |||
| int64_t groups = 1; | |||
| if (format == kOpFormat_FRAC_Z) { | |||
| groups = GetAttrGroups(node, index); | |||
| std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format, | |||
| const int64_t groups) { | |||
| using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>; | |||
| const std::map<std::string, DeviceShapeTransfer> device_shape_map{ | |||
| {kOpFormat_NCHW, NchwDeviceDynamicShape}, | |||
| {kOpFormat_NHWC, NhwcDeviceDynamicShape}, | |||
| {kOpFormat_HWCN, HwchDeviceDynamicShape}, | |||
| {kOpFormat_FRAC_Z, FracZDeviceDynamicShape}, | |||
| {kOpFormat_NC1HWC0, Nc1hwc0DeviceDynamicShape}, | |||
| {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceDynamicShape}, | |||
| {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceDynamicShape}, | |||
| {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceDynamicShape}, | |||
| {kOpFormat_NCDHW, NcdhwDeviceDynamicShape}, | |||
| {kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape}, | |||
| {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape}, | |||
| {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape}}; | |||
| if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { | |||
| return shape; | |||
| } | |||
| if (groups > 1 && format == kOpFormat_FRAC_Z) { | |||
| return FracZDeviceShapeWithGroups(shape, groups); | |||
| } | |||
| return TransShapeToDevice(shape, format, groups); | |||
| auto temp_shape = shape; | |||
| if (format == kOpFormat_FRAC_NZ) { | |||
| return TransShapeToFracNZ(shape); | |||
| } else if (format == kOpFormat_FRACTAL_ZN_LSTM) { | |||
| return TransShapeToFracNZLSTM(shape); | |||
| } | |||
| if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { | |||
| MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; | |||
| temp_shape = PaddingShapeTo4dDefault(shape); | |||
| } | |||
| if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) { | |||
| temp_shape = PaddingShapeTo5dDefault(shape); | |||
| } | |||
| auto iter = device_shape_map.find(format); | |||
| if (iter == device_shape_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; | |||
| } | |||
| return iter->second(temp_shape); | |||
| } | |||
| bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { | |||
| @@ -27,9 +27,11 @@ | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "ir/dtype/type.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace trans { | |||
| enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw }; | |||
| enum Axis5D : int { | |||
| N_ncdhw = 0, | |||
| C_ncdhw, | |||
| @@ -55,12 +57,7 @@ struct FormatArgs { | |||
| TypeId src_data_type; | |||
| }; | |||
| std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format, | |||
| const std::string &pad_index = {""}); | |||
| std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_axis = {""}); | |||
| std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_axis = {""}); | |||
| std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape); | |||
| std::vector<size_t> PaddingShapeTo4dDefault(const std::vector<size_t> &shape); | |||
| int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index); | |||
| void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec); | |||
| void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec); | |||
| ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); | |||
| @@ -68,8 +65,17 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size); | |||
| int64_t GetNodeGroups(const AnfNodePtr &node); | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, | |||
| const int64_t groups = 1); | |||
| std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, | |||
| const AnfNodePtr &node, const size_t index); | |||
| std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format, | |||
| const int64_t groups = 1); | |||
| template <typename T> | |||
| std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node, | |||
| const size_t index) { | |||
| int64_t groups = 1; | |||
| if (format == kOpFormat_FRAC_Z) { | |||
| groups = GetAttrGroups(node, index); | |||
| } | |||
| return TransShapeToDevice(shape, format, groups); | |||
| } | |||
| bool TransDataType(const TypeIdArgs &args, void *result); | |||
| bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1); | |||
| bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index); | |||
| @@ -104,6 +110,109 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{ | |||
| {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0}, | |||
| {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}, | |||
| {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}}; | |||
| template <typename T> | |||
| std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) { | |||
| if (shape.size() >= kNcdhw) { | |||
| return shape; | |||
| } | |||
| std::vector<T> shape_5d(kNcdhw, 1); | |||
| switch (shape.size()) { | |||
| case 0: | |||
| return shape_5d; | |||
| case 1: | |||
| shape_5d[1] = shape[0]; | |||
| break; | |||
| case 2: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| break; | |||
| case 3: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| shape_5d[3] = shape[2]; | |||
| break; | |||
| case 4: | |||
| shape_5d[1] = shape[0]; | |||
| shape_5d[2] = shape[1]; | |||
| shape_5d[3] = shape[2]; | |||
| shape_5d[4] = shape[3]; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); | |||
| } | |||
| return shape_5d; | |||
| } | |||
| template <typename T> | |||
| std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) { | |||
| std::vector<T> shape_4d(kNchwDims, 1); | |||
| switch (shape.size()) { | |||
| case 0: | |||
| return shape_4d; | |||
| case 1: | |||
| shape_4d[kC] = shape[kN]; | |||
| break; | |||
| case 2: | |||
| shape_4d[kC] = shape[kN]; | |||
| shape_4d[kH] = shape[kC]; | |||
| break; | |||
| case 3: | |||
| shape_4d[kC] = shape[kN]; | |||
| shape_4d[kH] = shape[kC]; | |||
| shape_4d[kW] = shape[kH]; | |||
| break; | |||
| case 4: | |||
| std::copy(shape.begin(), shape.end(), shape_4d.begin()); | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); | |||
| } | |||
| return shape_4d; | |||
| } | |||
| template <typename T> | |||
| std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) { | |||
| std::vector<Axis5D> padding_axis; | |||
| StringToAxisVector5D(padding_str, &padding_axis); | |||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||
| return PaddingShapeTo5dDefault(shape); | |||
| } | |||
| std::vector<T> shape_5d(kNcdhw, 1); | |||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||
| shape_5d[padding_axis[index]] = shape[index]; | |||
| } | |||
| return shape_5d; | |||
| } | |||
| template <typename T> | |||
| std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) { | |||
| std::vector<Axis> padding_axis; | |||
| StringToAxisVector4D(padding_str, &padding_axis); | |||
| if (padding_axis.empty() || shape.size() != padding_axis.size()) { | |||
| return PaddingShapeTo4dDefault(shape); | |||
| } | |||
| std::vector<T> shape_4d(kNchwDims, 1); | |||
| for (size_t index = 0; index < padding_axis.size(); index++) { | |||
| shape_4d[padding_axis[index]] = shape[index]; | |||
| } | |||
| return shape_4d; | |||
| } | |||
| template <typename T> | |||
| std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format, | |||
| const std::string &pad_index = {""}) { | |||
| std::vector<T> host_shape; | |||
| if (k3DFormatSet.find(format) != k3DFormatSet.end()) { | |||
| if (shape.size() >= kNcdhw) { | |||
| return shape; | |||
| } | |||
| host_shape = trans::PaddingShapeTo5d(shape, pad_index); | |||
| } else { | |||
| host_shape = trans::PaddingShapeTo4d(shape, pad_index); | |||
| } | |||
| return host_shape; | |||
| } | |||
| } // namespace trans | |||
| } // namespace mindspore | |||
| @@ -52,6 +52,7 @@ void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode, | |||
| auto input_node = input_node_with_index.first; | |||
| auto input_index = input_node_with_index.second; | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); | |||
| auto output_ori_shape = AnfAlgo::GetOutputInferShape(input_node, input_index); | |||
| auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index); | |||
| auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); | |||
| auto iter = type_name_map.find(output_dtype); | |||
| @@ -65,6 +66,7 @@ void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode, | |||
| tensor_arg.arg_type = optiling::TA_SINGLE; | |||
| tensor.dtype = ge_output_dtype; | |||
| tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end()); | |||
| tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end()); | |||
| tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); | |||
| MS_LOG(INFO) << "Tiling Format:" << tensor.format; | |||
| @@ -79,6 +81,7 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode, | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get()); | |||
| for (size_t i = 0; i < output_size; ++i) { | |||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i); | |||
| auto output_ori_shape = AnfAlgo::GetOutputInferShape(cnode.get(), i); | |||
| auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i); | |||
| auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i); | |||
| auto iter = type_name_map.find(data_type); | |||
| @@ -91,6 +94,7 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode, | |||
| tensor_arg.arg_type = optiling::TA_SINGLE; | |||
| tensor.dtype = iter->second; | |||
| tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end()); | |||
| tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end()); | |||
| tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); | |||
| MS_LOG(INFO) << "Tiling Format:" << tensor.format; | |||
| tensor_arg.tensor.emplace_back(tensor); | |||
| @@ -502,7 +502,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| auto attr = primitive->GetAttr("dst_type"); | |||
| if (attr == nullptr) { | |||
| attr = args_spec_list[1]->BuildValue(); | |||
| auto input_dtype = args_spec_list[1]; | |||
| MS_EXCEPTION_IF_NULL(input_dtype); | |||
| attr = input_dtype->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| primitive->set_attr("dst_type", attr); | |||
| } | |||
| @@ -24,13 +24,14 @@ | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| using mindspore::abstract::Shape; | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| // check functions | |||
| void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) { | |||
| for (size_t i = 0; i < shape.size(); ++i) { | |||
| if ((shape[i] < 0) && (shape[i] != abstract::Shape::SHP_ANY)) { | |||
| if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) { | |||
| MS_EXCEPTION(ValueError) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got " | |||
| << shape[i]; | |||
| } | |||
| @@ -74,28 +75,61 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa | |||
| const std::vector<int64_t> &dilation, const int64_t &pad_mode, | |||
| const std::vector<int64_t> &padding) { | |||
| if (pad_mode == PadMode::VALID) { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]))); | |||
| int64_t out_h = -1; | |||
| int64_t out_w = -1; | |||
| if (x_h != Shape::SHP_ANY) { | |||
| auto h_shape = static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])); | |||
| out_h = h_shape >= 1 ? h_shape : 1L; | |||
| } | |||
| if (x_w != Shape::SHP_ANY) { | |||
| auto w_shape = static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])); | |||
| out_w = w_shape >= 1 ? w_shape : 1L; | |||
| } | |||
| output_hw->push_back(out_h); | |||
| output_hw->push_back(out_w); | |||
| (void)pad_list->insert(pad_list->begin(), 4, 0); | |||
| } else if (pad_mode == PadMode::SAME) { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1]))); | |||
| int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; | |||
| pad_needed_h = std::max((int64_t)0, pad_needed_h); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2))); | |||
| pad_list->push_back(pad_needed_h - pad_list->at(0)); | |||
| int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w; | |||
| pad_needed_w = std::max((int64_t)0, pad_needed_w); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2))); | |||
| pad_list->push_back(pad_needed_w - pad_list->at(2)); | |||
| if (x_h == Shape::SHP_ANY) { | |||
| output_hw->push_back(Shape::SHP_ANY); | |||
| pad_list->push_back(Shape::SHP_ANY); | |||
| pad_list->push_back(Shape::SHP_ANY); | |||
| } else { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0]))); | |||
| int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; | |||
| pad_needed_h = std::max((int64_t)0, pad_needed_h); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2))); | |||
| pad_list->push_back(pad_needed_h - pad_list->at(0)); | |||
| } | |||
| if (x_w == Shape::SHP_ANY) { | |||
| output_hw->push_back(Shape::SHP_ANY); | |||
| pad_list->push_back(Shape::SHP_ANY); | |||
| pad_list->push_back(Shape::SHP_ANY); | |||
| } else { | |||
| output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1]))); | |||
| int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w; | |||
| pad_needed_w = std::max((int64_t)0, pad_needed_w); | |||
| pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2))); | |||
| pad_list->push_back(pad_needed_w - pad_list->at(2)); | |||
| } | |||
| } else if (pad_mode == PadMode::PAD) { | |||
| (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); | |||
| output_hw->push_back(static_cast<int64_t>(std::floor( | |||
| 1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / | |||
| stride[0]))); | |||
| output_hw->push_back(static_cast<int64_t>(std::floor( | |||
| 1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / | |||
| stride[1]))); | |||
| int64_t out_h = -1; | |||
| int64_t out_w = -1; | |||
| if (x_h != Shape::SHP_ANY) { | |||
| auto h_shape = static_cast<int64_t>(std::floor( | |||
| 1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / | |||
| stride[0])); | |||
| out_h = h_shape >= 1 ? h_shape : 1L; | |||
| } | |||
| if (x_w != Shape::SHP_ANY) { | |||
| auto w_shape = static_cast<int64_t>(std::floor( | |||
| 1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / | |||
| stride[1])); | |||
| out_w = w_shape >= 1 ? w_shape : 1L; | |||
| } | |||
| output_hw->push_back(out_h); | |||
| output_hw->push_back(out_w); | |||
| } | |||
| } | |||
| @@ -131,20 +165,20 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| w_axis = 2; | |||
| } | |||
| int64_t group = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("group"), "group"); | |||
| if ((x_shape[c_axis] != abstract::Shape::SHP_ANY) && (w_shape[c_axis] != abstract::Shape::SHP_ANY) && | |||
| if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) && | |||
| ((x_shape[c_axis] / group) != w_shape[c_axis])) { | |||
| MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got " | |||
| << (x_shape[c_axis] / group); | |||
| } | |||
| int64_t out_channel = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("out_channel"), "out_channel"); | |||
| if ((w_shape[n_axis] != abstract::Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { | |||
| if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { | |||
| MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must equal to = " << out_channel; | |||
| } | |||
| std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(prim_name, primitive->GetAttr("kernel_size"), 0, 2); | |||
| if ((w_shape[h_axis] != abstract::Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) { | |||
| if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) { | |||
| MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << ", must equal to = " << kernel_size[0]; | |||
| } | |||
| if ((w_shape[w_axis] != abstract::Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) { | |||
| if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) { | |||
| MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << ", must equal to = " << kernel_size[1]; | |||
| } | |||
| std::vector<int64_t> stride = CheckAttrIntOrTuple(prim_name, primitive->GetAttr("stride"), 2, 2); | |||
| @@ -160,16 +194,6 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||
| std::vector<int64_t> pad_list_max; | |||
| Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode, | |||
| padding); | |||
| if (x_shape[h_axis] == abstract::Shape::SHP_ANY) { | |||
| output_hw[0] = abstract::Shape::SHP_ANY; | |||
| pad_list[0] = abstract::Shape::SHP_ANY; | |||
| pad_list[1] = abstract::Shape::SHP_ANY; | |||
| } | |||
| if (x_shape[w_axis] == abstract::Shape::SHP_ANY) { | |||
| output_hw[1] = abstract::Shape::SHP_ANY; | |||
| pad_list[2] = abstract::Shape::SHP_ANY; | |||
| pad_list[3] = abstract::Shape::SHP_ANY; | |||
| } | |||
| Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride, | |||
| dilation, pad_mode, padding); | |||
| Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride, | |||
| @@ -46,6 +46,7 @@ from .assign import _assign_tbe | |||
| from .assign_add import _assign_add_tbe | |||
| from .assign_sub import _assign_sub_tbe | |||
| from .batch_matmul import _batch_matmul_tbe | |||
| from .batch_matmul_ds import _batch_matmul_ds_tbe | |||
| from .batchnorm import _batch_norm_tbe | |||
| from .batchnorm_grad import _batch_norm_grad_tbe | |||
| from .bias_add import _bias_add_tbe | |||
| @@ -55,6 +56,7 @@ from .cast_ds import _cast_ds_tbe | |||
| from .conv2d import _conv2d_tbe | |||
| from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe | |||
| from .conv2d_backprop_input import _conv2d_backprop_input_tbe | |||
| from .conv2d_ds import _conv2d_ds_tbe | |||
| from .confusion_mul_grad import _confusion_mul_grad_tbe | |||
| from .dropout_do_mask import _dropout_do_mask_tbe | |||
| from .dropout_do_mask_ds import _dropout_do_mask_ds_tbe | |||
| @@ -92,6 +94,7 @@ from .trans_data import _trans_data_tbe | |||
| from .trans_data_ds import _trans_data_ds_tbe | |||
| from .top_k import _top_k_tbe | |||
| from .matmul import _matmul_tbe | |||
| from .matmul_ds import _matmul_ds_tbe | |||
| from .sub import _sub_tbe | |||
| from .sub_ds import _sub_ds_tbe | |||
| from .scatter_nd import _scatter_nd_tbe | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conv2D op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Operator-information registration for the dynamic-shape ("ds") variant of
# Conv2D on Ascend/TBE:
#   - dynamic_shape(True) marks the TBE kernel as supporting unknown dims,
#   - is_dynamic_format(True) defers the concrete data format to op-select at
#     compile time, which is why dtype_format below uses the generic
#     *_None placeholders instead of fixed layouts.
# NOTE(review): fusion_type "CONVLUTION" looks misspelled but appears to be the
# framework-defined spelling — confirm against the TBE fusion-type constants.
conv2d_op_info = TBERegOp("Conv2D") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv2d.so") \
    .compute_cost(10) \
    .kernel_name("conv2d") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .attr("stride", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("dilation", "required", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("format", "optional", "str", "all") \
    .attr("offset_x", "optional", "int", "all", "0") \
    .input(0, "x", False, "required", "all") \
    .input(1, "filter", False, "required", "all") \
    .input(2, "bias", False, "optional", "all") \
    .input(3, "offset_w", False, "optional", "all") \
    .output(0, "y", True, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
    .get_op_info()

# NOTE(review): a concrete-layout registration was left disabled — confirm
# whether is_dynamic_format(True) makes it unnecessary:
# .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, DataType.F16_Default) ?


@op_info_register(conv2d_op_info)
def _conv2d_ds_tbe():
    """Register the dynamic-shape Conv2D TBE operator.

    The registration payload lives entirely in the ``@op_info_register``
    decorator; the function body is intentionally empty.
    """
    return