@@ -66,7 +66,7 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
-std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
+std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
 if (input_reshape_type_.empty()) {
 return {};
 }
@@ -77,7 +77,7 @@ std::vector<Axis> KernelBuildInfo::GetInputReshapeType(size_t input_index) const
 return input_reshape_type_[input_index];
 }
-std::vector<Axis> KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
+std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
 if (output_reshape_type_.empty()) {
 return {};
 }
@@ -175,14 +175,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
 std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
-void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
-const std::vector<std::vector<Axis>> &input_reshape_type) {
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
 MS_EXCEPTION_IF_NULL(kernel_build_info_);
 kernel_build_info_->input_reshape_type_ = input_reshape_type;
 }
 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
-const std::vector<std::vector<Axis>> &output_reshape_type) {
+const std::vector<std::string> &output_reshape_type) {
 MS_EXCEPTION_IF_NULL(kernel_build_info_);
 kernel_build_info_->output_reshape_type_ = output_reshape_type;
 }
@@ -206,8 +205,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
 }
 kernel_build_info_->outputs_format_[index] = format;
 }
-void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
-size_t index) {
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
 if (index >= kernel_build_info_->input_reshape_type_.size()) {
 MS_LOG(EXCEPTION) << "index outof range!";
 }
@@ -215,7 +213,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vec
 std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
 }
-void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
 size_t index) {
 if (index >= kernel_build_info_->output_reshape_type_.size()) {
 MS_LOG(EXCEPTION) << "index outof range!";
@@ -57,13 +57,13 @@ class KernelBuildInfo {
 TypeId GetOutputDeviceType(size_t output_index) const;
-std::vector<Axis> GetInputReshapeType(size_t input_index) const;
+std::string GetInputReshapeType(size_t input_index) const;
 bool IsInputDefaultPadding() const;
 bool IsOutputDefaultPadding() const;
-std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
+std::string GetOutputReshapeType(size_t input_index) const;
 const std::string &GetOriginDataFormat() const;
@@ -75,9 +75,9 @@ class KernelBuildInfo {
 const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
-std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
+std::vector<std::string> GetAllOutputReshapeType() const;
-std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
+std::vector<std::string> GetAllInputReshapeType() const;
 OpPattern op_pattern() const { return op_pattern_; }
@@ -106,8 +106,8 @@ class KernelBuildInfo {
 std::vector<std::string> inputs_format_;
 OpPattern op_pattern_;
 std::vector<std::string> outputs_format_;
-std::vector<std::vector<Axis>> input_reshape_type_;
-std::vector<std::vector<Axis>> output_reshape_type_;
+std::vector<std::string> input_reshape_type_;
+std::vector<std::string> output_reshape_type_;
 std::vector<TypeId> inputs_device_type_;
 std::vector<TypeId> outputs_device_type_;
 FusionType fusion_type_;
@@ -151,9 +151,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
 void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
-void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
+void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
-void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
+void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
 void SetFusionType(FusionType fusion_type);
@@ -165,9 +165,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
 void SetOutputFormat(const std::string &format, size_t index);
-void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
+void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
-void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
+void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
 void SetInputDeviceType(const TypeId &input_device_type, size_t index);
@@ -99,7 +99,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
 SetTbeBuildCommonInfo(op_info, &builder);
 std::vector<std::string> inputs_format;
 std::vector<TypeId> inputs_device_type;
-std::vector<std::vector<Axis>> inputs_reshape_type;
+std::vector<std::string> inputs_reshape_type;
 // input
 if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
 &inputs_format, &inputs_device_type, &inputs_reshape_type)) {
@@ -111,7 +111,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
 // output
 std::vector<std::string> outputs_format;
 std::vector<TypeId> outputs_device_type;
-std::vector<std::vector<Axis>> outputs_reshape_type;
+std::vector<std::string> outputs_reshape_type;
 if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
 &outputs_format, &outputs_device_type, &outputs_reshape_type)) {
 break;
@@ -290,7 +290,7 @@ std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
 bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
 const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
 const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
-std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types) {
+std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types) {
 MS_EXCEPTION_IF_NULL(formats);
 MS_EXCEPTION_IF_NULL(device_types);
 MS_EXCEPTION_IF_NULL(reshape_types);
@@ -306,8 +306,7 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
 kernel_build_info_format = io_info_item->formats()[kernel_build_info_index];
 }
 const std::string &io_param_type = io_info_item->param_type();
-std::vector<Axis> reshape_type;
-StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
+auto reshape_type = io_info_item->reshape_type();
 if (io_param_type == kParamTypeDynamic) {
 // dynamic io
 if (is_input) {
@@ -355,28 +354,6 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
 return true;
 }
-void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
-MS_EXCEPTION_IF_NULL(reshape_type_vec);
-for (const auto &c : reshape_type_str) {
-switch (c) {
-case 'N':
-reshape_type_vec->push_back(N);
-break;
-case 'C':
-reshape_type_vec->push_back(C);
-break;
-case 'H':
-reshape_type_vec->push_back(H);
-break;
-case 'W':
-reshape_type_vec->push_back(W);
-break;
-default:
-MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
-}
-}
-}
 void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
 const std::vector<std::vector<std::string>> &support_format_item, size_t index,
 mindspore::kernel::OpIOInfo *op_io_info_new) {
@@ -52,8 +52,7 @@ class TbeKernelSelect {
 bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
 const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
 const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
-std::vector<TypeId> *device_types, std::vector<std::vector<Axis>> *reshape_types);
-static void StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
+std::vector<TypeId> *device_types, std::vector<std::string> *reshape_types);
 static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new);
 static void CreateNewOpIOInfo(const OpIOInfo &op_io_info,
 const std::vector<std::vector<std::string>> &support_format_item, size_t index,
@@ -187,8 +187,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
 AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
 std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
 std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
-std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
-: AnfAlgo::GetOutputReshapeType(node, insert_index);
+std::string padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
+: AnfAlgo::GetOutputReshapeType(node, insert_index);
 auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
 : AnfAlgo::GetOutputInferShape(input_node, insert_index);
 bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
@@ -200,8 +200,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
 } else if (is_insert_input) {
 // if need padding & is input need insert a transdata
 // reshape[padding shape] -> transdata[padding shape] -> node
-auto padding_shape =
-trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
+auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
+AnfAlgo::GetInputReshapeType(node, insert_index));
 auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
 trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
 trans_node = trans_data;
@@ -222,8 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
 }
 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
-const TypeId &type_id) {
+const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) {
 MS_EXCEPTION_IF_NULL(trans_data);
 auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
 MS_EXCEPTION_IF_NULL(ori_build_info);
@@ -249,9 +248,10 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
 if (need_padding) {
 // if need padding we should set the transdata node's shape to the padding shape
 auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
-trans_node.get());
+AnfAlgo::SetOutputInferTypeAndShape(
+{AnfAlgo::GetOutputInferDataType(input, 0)},
+{trans::PaddingShape(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputFormat(input, 0), padding_axis)},
+trans_node.get());
 } else {
 AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
 {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -273,7 +273,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
 CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
 const TypeId &input_type, const TypeId &output_type,
 const std::vector<size_t> &origin_shape, const TypeId &origin_type,
-const std::vector<Axis> &reshape_type) {
+const std::string &reshape_type) {
 MS_EXCEPTION_IF_NULL(func_graph);
 std::string input_format = format;
 std::string output_format = format;
@@ -88,7 +88,7 @@ class OpFinder {
 using OpFinderPtr = std::shared_ptr<OpFinder>;
 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type = {},
+const AnfNodePtr &trans_data, const std::string &reshape_type = {""},
 const TypeId &type_id = kTypeUnknown);
 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
@@ -97,7 +97,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
 CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
 const TypeId &input_type, const TypeId &output_type,
 const std::vector<size_t> &origin_shape, const TypeId &origin_type,
-const std::vector<Axis> &reshape_type = std::vector<Axis>{});
+const std::string &reshape_type = std::string{});
 AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
 const KernelSelectPtr &kernel_select);
@@ -586,7 +586,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
 return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
 }
-std::vector<Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
 KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
 return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
 }
@@ -642,7 +642,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &
 }
 // if format is default_format or NC1KHKWHWC0,device shape = original shape
 if (trans::IsNeedPadding(format, infer_shape.size())) {
-infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
+infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
 }
 return trans::TransShapeToDevice(infer_shape, format);
 }
@@ -655,12 +655,12 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
 }
 // if format is default_format or NC1KHKWHWC0,device shape = original shape
 if (trans::IsNeedPadding(format, infer_shape.size())) {
-infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
+infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx));
 }
 return trans::TransShapeToDevice(infer_shape, format);
 }
-std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
 MS_EXCEPTION_IF_NULL(node);
 if (input_idx > GetInputTensorNum(node)) {
 MS_LOG(EXCEPTION) << "The index:" << input_idx
@@ -681,7 +681,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod
 return build_info->GetInputReshapeType(input_idx);
 }
-std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
+std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
 MS_EXCEPTION_IF_NULL(node);
 if (output_idx > GetOutputTensorNum(node)) {
 MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
@@ -122,7 +122,7 @@ class AnfRuntimeAlgorithm {
 // get output format from prev node,input_index is the input index of current node related to prev node
 static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
 // get reshape_type of from the output of input node.
-static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
+static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
 // get output shapes inferred by ME from input nodes.
 static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
 // get input shapes inferred by ME from input nodes.
@@ -132,9 +132,9 @@ class AnfRuntimeAlgorithm {
 // get input shapes which will built and run in device
 static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
 // Get Input Padding Axis
-static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
+static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
 // Get Output Padding Axis
-static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
+static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
 // get output data type inferred by ME of anf node
 static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
 // get output original data type from prev node,input_index is the input index of current node related to prev node
@@ -21,6 +21,7 @@
 #include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
 #include "runtime/device/convert_tensor_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -28,7 +29,7 @@
 namespace mindspore {
 namespace trans {
-enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
+enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
 switch (size) {
 case 1:
@@ -343,7 +344,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
 }
 std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
-if (shape.size() < kNdhwc) {
+if (shape.size() < kNcdhw) {
 MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
 }
 return shape;
@@ -388,6 +389,20 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
 return false;
 }
+std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
+const std::string &pad_index) {
+std::vector<size_t> host_shape;
+if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
+if (shape.size() >= kNcdhw) {
+return shape;
+}
+host_shape = trans::PaddingShapeTo5d(shape, pad_index);
+} else {
+host_shape = trans::PaddingShapeTo4d(shape, pad_index);
+}
+return host_shape;
+}
 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
 MS_EXCEPTION_IF_NULL(node);
 ShapeVector shape;
@@ -409,14 +424,84 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
 } else {
 host_shape = AnfAlgo::GetOutputInferShape(node, index);
 }
-if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
-host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
+auto format = AnfAlgo::GetOutputFormat(node, index);
+if (trans::IsNeedPadding(format, host_shape.size())) {
+host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
 }
 std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
 return shape;
 }
-std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
+void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+MS_EXCEPTION_IF_NULL(reshape_type_vec);
+if (reshape_type_str.empty()) {
+MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
+return;
+}
+for (const auto &c : reshape_type_str) {
+switch (c) {
+case 'N':
+reshape_type_vec->push_back(N);
+break;
+case 'C':
+reshape_type_vec->push_back(C);
+break;
+case 'H':
+reshape_type_vec->push_back(H);
+break;
+case 'W':
+reshape_type_vec->push_back(W);
+break;
+default:
+MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
+}
+}
+}
+void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
+MS_EXCEPTION_IF_NULL(reshape_type_vec);
+if (reshape_type_str.empty()) {
+MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
+return;
+}
+for (const auto &c : reshape_type_str) {
+switch (c) {
+case 'N':
+reshape_type_vec->push_back(N_ncdhw);
+break;
+case 'C':
+reshape_type_vec->push_back(C_ncdhw);
+break;
+case 'D':
+reshape_type_vec->push_back(D_ncdhw);
+break;
+case 'H':
+reshape_type_vec->push_back(H_ncdhw);
+break;
+case 'W':
+reshape_type_vec->push_back(W_ncdhw);
+break;
+default:
+MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
+}
+}
+}
+std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_str) {
+std::vector<Axis5D> padding_axis;
+StringToAxisVector5D(padding_str, &padding_axis);
+if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+return PaddingShapeTo5dDefault(shape);
+}
+std::vector<size_t> shape_5d(kNcdhw, 1);
+for (size_t index = 0; index < padding_axis.size(); index++) {
+shape_5d[padding_axis[index]] = shape[index];
+}
+return shape_5d;
+}
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_str) {
+std::vector<Axis> padding_axis;
+StringToAxisVector4D(padding_str, &padding_axis);
 if (padding_axis.empty() || shape.size() != padding_axis.size()) {
 return PaddingShapeTo4dByDefault(shape);
 }
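The hunk above drives 5-D padding from a reshape-type string: each character names the NCDHW slot that the corresponding host dimension occupies, and unnamed slots stay at 1. Below is a minimal standalone sketch of that mapping; the helper name PadTo5dByReshapeType and the sample shape are illustrative assumptions, not part of the patch.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Sketch: place each host dimension into the NCDHW slot named by the
// reshape-type string; slots that are not named are padded with 1.
std::vector<size_t> PadTo5dByReshapeType(const std::vector<size_t> &shape, const std::string &reshape_type) {
  static const std::map<char, size_t> kSlot = {{'N', 0}, {'C', 1}, {'D', 2}, {'H', 3}, {'W', 4}};
  if (reshape_type.size() != shape.size()) {
    return shape;  // the real code falls back to default padding in this case
  }
  std::vector<size_t> shape_5d(5, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    shape_5d[kSlot.at(reshape_type[i])] = shape[i];
  }
  return shape_5d;
}

int main() {
  // A "CDH" reshape type maps {8, 16, 32} onto NCDHW as {1, 8, 16, 32, 1}.
  for (size_t dim : PadTo5dByReshapeType({8, 16, 32}, "CDH")) {
    std::cout << dim << ' ';
  }
  std::cout << std::endl;
  return 0;
}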
@@ -427,6 +512,38 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
 return shape_4d;
 }
+std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
+if (shape.size() >= kNcdhw) {
+return shape;
+}
+std::vector<size_t> shape_5d(kNcdhw, 1);
+switch (shape.size()) {
+case 0:
+return shape_5d;
+case 1:
+shape_5d[1] = shape[0];
+break;
+case 2:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+break;
+case 3:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+shape_5d[3] = shape[2];
+break;
+case 4:
+shape_5d[1] = shape[0];
+shape_5d[2] = shape[1];
+shape_5d[3] = shape[2];
+shape_5d[4] = shape[3];
+break;
+default:
+MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+}
+return shape_5d;
+}
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
 using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
 const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
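When no usable reshape-type string is available, the new PaddingShapeTo5dDefault above copies the host dimensions in order into the slots after N and fills the rest with 1. The snippet below is a self-contained sketch of that default behavior, written for illustration only; the function name and test values are assumed rather than taken from the project.

#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the default 5-D padding: shape[i] lands in slot i + 1 of an
// NCDHW shape, mirroring the switch statement in the hunk above.
std::vector<size_t> PadTo5dDefaultSketch(const std::vector<size_t> &shape) {
  if (shape.size() >= 5) {
    return shape;
  }
  std::vector<size_t> shape_5d(5, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    shape_5d[i + 1] = shape[i];
  }
  return shape_5d;
}

int main() {
  assert((PadTo5dDefaultSketch({16, 32}) == std::vector<size_t>{1, 16, 32, 1, 1}));
  assert((PadTo5dDefaultSketch({2, 3, 4, 5}) == std::vector<size_t>{1, 2, 3, 4, 5}));
  return 0;
}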
@@ -475,10 +592,13 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
 device_shape.push_back(kCubeSize);
 return device_shape;
 }
-if (shape.size() != kNchwDims && shape.size() != 5) {
+if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
 MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
 temp_shape = PaddingShapeTo4dByDefault(shape);
 }
+if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
+temp_shape = PaddingShapeTo5dDefault(shape);
+}
 auto iter = device_shape_map.find(format);
 if (iter == device_shape_map.end()) {
 MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
@@ -30,6 +30,13 @@
 namespace mindspore {
 namespace trans {
+enum Axis5D : int {
+N_ncdhw = 0,
+C_ncdhw,
+D_ncdhw,
+H_ncdhw,
+W_ncdhw,
+};
 struct TypeIdArgs {
 const void *data;
 size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d
@@ -50,7 +57,13 @@ struct FormatArgs {
 size_t CubeSizeByType(const TypeId data_type);
-std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
+std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
+const std::string &pad_index = {""});
+std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
+std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
+std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape);
+void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
+void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
 ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
 bool IsNeedPadding(const std::string &format, const size_t shape_size);
 std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
@@ -475,7 +475,7 @@ std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *hos
 device_shape = trans::TransShapeToDevice(*host_shape, format_);
 } else {
 if (host_shape_.empty()) {
-*host_shape = trans::PaddingShapeTo4d(*host_shape);
+*host_shape = trans::PaddingShape(*host_shape, format_);
 } else {
 host_shape->clear();
 (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize);
@@ -595,11 +595,10 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
 host_shape.emplace_back(1);
 }
 std::vector<size_t> device_shape;
-if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 ||
-format_ == kOpFormat_FRACTAL_Z_3D) {
+if (format_ == kOpFormat_FRAC_NZ) {
 device_shape = trans::TransShapeToDevice(host_shape, format_);
 } else {
-host_shape = trans::PaddingShapeTo4d(host_shape);
+host_shape = trans::PaddingShape(host_shape, format_);
 device_shape = trans::TransShapeToDevice(host_shape, format_);
 }
 if (type_id_ != type) {
@@ -68,7 +68,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
 std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
 auto format = AnfAlgo::GetOutputFormat(node, output_index);
 if (shape.empty() && format != kOpFormat_DEFAULT) {
-shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
+shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
 shape = trans::TransShapeToDevice(shape, format);
 }
 // scalar's output shape is a empty vector
@@ -303,6 +303,7 @@ constexpr auto kAttrFactor = "factor";
 constexpr auto kAttrIsRef = "isRef";
 constexpr auto kAttrDataShape = "data_shape";
 constexpr auto kAttrFormat = "format";
+constexpr auto kAttrReshapeType = "reshape_type";
 constexpr auto kAttrAxis = "axis";
 constexpr auto kAttrKeepDims = "keep_dims";
 constexpr auto kAttrShapeGamma = "shape_gamma";
@@ -285,8 +285,8 @@ class Tensor : public MetaTensor {
 DeviceSyncPtr device_address() const { return device_sync_; }
 void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
-void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
-std::vector<Axis> padding_type() const { return padding_type_; }
+void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; }
+std::string padding_type() const { return padding_type_; }
 std::string id() const { return id_; }
 TypePtr cast_dtype() { return cast_dtype_; }
@@ -366,7 +366,7 @@ class Tensor : public MetaTensor {
 bool cache_enable_{false};
 std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
 std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
-std::vector<Axis> padding_type_;
+std::string padding_type_{""};
 TypePtr cast_dtype_{nullptr};
 std::shared_ptr<DeviceEvent> device_event_{nullptr};
 };
@@ -50,8 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
 KernelBuildInfoBuilder builder;
 builder.SetInputsFormat({format, format});
 builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
-builder.SetInputsReshapeType({{},{}});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
 builder.SetOutputsFormat({format});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
 add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -72,8 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
 EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
 auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
 KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{},{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({"", ""});
 builder.SetInputsFormat({kOpFormat_DEFAULT});
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({format, format});
@@ -92,8 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
 ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
 void SelectKernel(const CNodePtr &cnode) override {
 KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 builder.SetInputsFormat({"NCHW"});
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
@@ -53,8 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
 KernelBuildInfoBuilder builder;
 builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
 builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
 builder.SetOutputsFormat({kOpFormat_NC1HWC0});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
 add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -80,8 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
 kg->AddInternalOutput(tuple_getitem1, max_pool, 0, true);
 kg->AddInternalOutput(tuple_getitem2, max_pool, 1, true);
 KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}, {}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({"", ""});
 builder.SetInputsFormat({kOpFormat_DEFAULT});
 builder.SetInputsDeviceType({kFloat32->type_id()});
 builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@@ -103,8 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({kOpFormat_DEFAULT});
 builder.SetOutputsDeviceType({kFloat32->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 }
 };
@@ -51,8 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 } else {
 KernelBuildInfoBuilder builder;
@@ -60,8 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 }
@@ -79,8 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NCHW"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 } else {
 KernelBuildInfoBuilder builder;
@@ -88,8 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NCHW"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 }
 }
@@ -125,8 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
 builder.SetKernelType(KernelType::TBE_KERNEL);
 builder.SetFusionType(kernel::FusionType::ELEMWISE);
 builder.SetProcessor(kernel::Processor::AICORE);
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 auto kernel_info = std::make_shared<device::KernelInfo>();
 kernel_info->set_select_kernel_build_info(builder.Build());
 transpose->set_kernel_info(kernel_info);
@@ -173,8 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
 builder.SetKernelType(KernelType::TBE_KERNEL);
 builder.SetFusionType(kernel::FusionType::ELEMWISE);
 builder.SetProcessor(kernel::Processor::AICORE);
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 auto kernel_info = std::make_shared<device::KernelInfo>();
 kernel_info->set_select_kernel_build_info(builder.Build());
 transpose->set_kernel_info(kernel_info);
@@ -58,8 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 } else {
 KernelBuildInfoBuilder builder;
@@ -67,8 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
 }
 }
@@ -97,8 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
 EXPECT_NE(transpose, nullptr);
 KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({});
-builder.SetOutputsReshapeType({});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 builder.SetInputsFormat({"NCHW"});
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
@@ -56,8 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
 ~MockEliminate5To4And4To5KernelSelect() override = default;
 void SelectKernel(const CNodePtr &cnode) override {
 KernelBuildInfoBuilder builder;
-builder.SetInputsReshapeType({{}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({""});
+builder.SetOutputsReshapeType({""});
 builder.SetInputsFormat({"NCHW"});
 builder.SetInputsDeviceType({kFloat16->type_id()});
 builder.SetOutputsFormat({"NC1HWC0"});
@@ -104,8 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
 sub->set_kernel_info(std::make_shared<device::KernelInfo>());
 add->set_kernel_info(std::make_shared<device::KernelInfo>());
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -171,8 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}, {}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({"", ""});
 sub->set_kernel_info(std::make_shared<device::KernelInfo>());
 add->set_kernel_info(std::make_shared<device::KernelInfo>());
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -248,8 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
 builder.SetOutputsFormat({"NC1HWC0"});
 builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
 builder.SetOutputsDeviceType({kFloat16->type_id()});
-builder.SetInputsReshapeType({{}, {}});
-builder.SetOutputsReshapeType({{}});
+builder.SetInputsReshapeType({"", ""});
+builder.SetOutputsReshapeType({""});
 sub->set_kernel_info(std::make_shared<device::KernelInfo>());
 add->set_kernel_info(std::make_shared<device::KernelInfo>());
 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());