| @@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) | |||
| std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType( | |||
| const std::vector<std::vector<Axis>> &input_reshape_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->input_reshape_type_ = input_reshape_type; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType( | |||
| const std::vector<std::vector<Axis>> &output_reshape_type) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->output_reshape_type_ = output_reshape_type; | |||
| @@ -189,5 +189,37 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string | |||
| } | |||
| kernel_build_info_->outputs_format_[index] = format; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type, | |||
| size_t index) { | |||
| if (index >= kernel_build_info_->input_reshape_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| std::copy(input_reshape_type.begin(), input_reshape_type.end(), | |||
| std::back_inserter(kernel_build_info_->input_reshape_type_[index])); | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, | |||
| size_t index) { | |||
| if (index >= kernel_build_info_->output_reshape_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| std::copy(output_reshape_type.begin(), output_reshape_type.end(), | |||
| std::back_inserter(kernel_build_info_->output_reshape_type_[index])); | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) { | |||
| if (index >= kernel_build_info_->outputs_device_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| kernel_build_info_->outputs_device_type_[index] = output_device_type; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) { | |||
| if (index >= kernel_build_info_->inputs_device_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| kernel_build_info_->inputs_device_type_[index] = input_device_type; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -71,6 +71,10 @@ class KernelBuildInfo { | |||
| std::vector<TypeId> GetAllOutputDeviceTypes() const; | |||
| std::vector<std::vector<Axis>> GetAllOutputReshapeType() const; | |||
| std::vector<std::vector<Axis>> GetAllInputReshapeType() const; | |||
| OpPattern op_pattern() const { return op_pattern_; } | |||
| FusionType fusion_type() const { return fusion_type_; } | |||
| @@ -108,8 +112,23 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||
| public: | |||
| KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); } | |||
| explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info) | |||
| : kernel_build_info_(std::move(kernel_build_info)) {} | |||
| explicit KernelBuildInfoBuilder(const std::shared_ptr<KernelBuildInfo> &kernel_build_info) | |||
| : kernel_build_info_(std::make_shared<KernelBuildInfo>()) { | |||
| SetKernelType(kernel_build_info->kernel_type()); | |||
| SetFusionType(kernel_build_info->fusion_type()); | |||
| SetProcessor(kernel_build_info->processor()); | |||
| OpPattern(kernel_build_info->op_pattern()); | |||
| for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) { | |||
| kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index)); | |||
| kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index)); | |||
| kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index)); | |||
| } | |||
| for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) { | |||
| kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index)); | |||
| kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index)); | |||
| kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index)); | |||
| } | |||
| } | |||
| ~KernelBuildInfoBuilder() = default; | |||
| @@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||
| void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type); | |||
| void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type); | |||
| void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type); | |||
| void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type); | |||
| void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type); | |||
| void SetFusionType(FusionType fusion_type); | |||
| @@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||
| void SetOutputFormat(const std::string &format, size_t index); | |||
| void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index); | |||
| void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index); | |||
| void SetInputDeviceType(const TypeId &input_device_type, size_t index); | |||
| void SetOutputDeviceType(const TypeId &output_device_type, size_t index); | |||
| std::shared_ptr<KernelBuildInfo> Build(); | |||
| private: | |||
| @@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | |||
| } | |||
| builder.SetInputsDeviceType(inputs_device_type); | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetInputReshapeType(inputs_reshape_type); | |||
| builder.SetInputsReshapeType(inputs_reshape_type); | |||
| // output | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> outputs_device_type; | |||
| @@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { | |||
| } | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetOutputReshapeType(outputs_reshape_type); | |||
| builder.SetOutputsReshapeType(outputs_reshape_type); | |||
| kernel_info_list_->emplace_back(builder.Build()); | |||
| } | |||
| MS_LOG(INFO) << "end."; | |||
| @@ -47,6 +47,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fission/transdata_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/topk_split.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" | |||
| @@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); | |||
| optimizer->AddPassManager(mixed_precision_pm); | |||
| @@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & | |||
| MS_EXCEPTION_IF_NULL(ori_build_info); | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info); | |||
| builder->SetInputsFormat({input_format}); | |||
| builder->SetInputReshapeType({reshape_type}); | |||
| builder->SetOutputReshapeType({reshape_type}); | |||
| builder->SetInputsReshapeType({reshape_type}); | |||
| builder->SetOutputsReshapeType({reshape_type}); | |||
| builder->SetOutputsFormat({output_format}); | |||
| if (type_id != kTypeUnknown) { | |||
| builder->SetOutputsDeviceType({type_id}); | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef SplitUnsupportedTransData::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| return VectorRef({prim::KPrimTransData, X}); | |||
| } | |||
| const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| auto ori_trans_data = node->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) { | |||
| return nullptr; | |||
| } | |||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) { | |||
| MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1" | |||
| << ori_trans_data->DebugString(); | |||
| } | |||
| return SplitTransData(func_graph, ori_trans_data); | |||
| } | |||
| AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const { | |||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node); | |||
| if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || | |||
| kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { | |||
| return trans_node; | |||
| } | |||
| auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); | |||
| std::vector<AnfNodePtr> next_trans_node_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node}; | |||
| auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); | |||
| next_trans_node->set_abstract(trans_node->abstract()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); | |||
| return next_trans_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class SplitUnsupportedTransData : public PatternProcessPass { | |||
| public: | |||
| explicit SplitUnsupportedTransData(bool multigraph = true) | |||
| : PatternProcessPass("split_unsupported_transdata", multigraph) {} | |||
| ~SplitUnsupportedTransData() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| @@ -51,6 +51,8 @@ class TestHWInsertTransOp : public BackendCommon { | |||
| builder.SetInputsFormat({format, format}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({format}); | |||
| builder.SetInputsReshapeType({{},{}}); | |||
| builder.SetOutputsReshapeType({}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| add->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get()); | |||
| @@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon { | |||
| EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr); | |||
| auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsReshapeType({{},{}}); | |||
| builder.SetOutputsReshapeType({}); | |||
| builder.SetInputsFormat({kOpFormat_DEFAULT}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({format, format}); | |||
| @@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default; | |||
| void SelectKernel(const CNodePtr &cnode) override { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| builder.SetInputsFormat({"NCHW"}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| @@ -52,6 +52,8 @@ class TestHWRemoveInternalOutput : public BackendCommon { | |||
| kg->AddInternalOutput(add, add); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); | |||
| builder.SetInputsReshapeType({{}, {}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); | |||
| builder.SetOutputsFormat({kOpFormat_NC1HWC0}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| @@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon { | |||
| kg->AddInternalOutput(tuple_getitem1, max_pool); | |||
| kg->AddInternalOutput(tuple_getitem2, max_pool); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}, {}}); | |||
| builder.SetInputsFormat({kOpFormat_DEFAULT}); | |||
| builder.SetInputsDeviceType({kFloat32->type_id()}); | |||
| builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); | |||
| @@ -95,6 +99,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect { | |||
| ~MockRemoveInternalOutputTransOpKernelSelect() override = default; | |||
| void SelectKernel(const CNodePtr &cnode) override { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| builder.SetInputsFormat({kOpFormat_NC1HWC0}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| @@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); | |||
| } else { | |||
| KernelBuildInfoBuilder builder; | |||
| @@ -58,6 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); | |||
| } | |||
| } | |||
| @@ -74,10 +78,14 @@ class MockTransdataSplitKernelSelect : public KernelSelect { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NCHW"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); | |||
| } else { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({"NCHW"}); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NCHW"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| @@ -116,6 +124,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) { | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| builder.SetFusionType(kernel::FusionType::ELEMWISE); | |||
| builder.SetProcessor(kernel::Processor::AICORE); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| kernel_info->set_select_kernel_build_info(builder.Build()); | |||
| transpose->set_kernel_info(kernel_info); | |||
| @@ -162,6 +172,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) { | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| builder.SetFusionType(kernel::FusionType::ELEMWISE); | |||
| builder.SetProcessor(kernel::Processor::AICORE); | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| kernel_info->set_select_kernel_build_info(builder.Build()); | |||
| transpose->set_kernel_info(kernel_info); | |||
| @@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({}); | |||
| builder.SetOutputsReshapeType({}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); | |||
| } else { | |||
| KernelBuildInfoBuilder builder; | |||
| @@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({}); | |||
| builder.SetOutputsReshapeType({}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); | |||
| } | |||
| } | |||
| @@ -97,6 +101,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({}); | |||
| builder.SetOutputsReshapeType({}); | |||
| builder.SetKernelType(KernelType::TBE_KERNEL); | |||
| builder.SetFusionType(kernel::FusionType::ELEMWISE); | |||
| builder.SetProcessor(kernel::Processor::AICORE); | |||
| @@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect { | |||
| ~MockEliminate5To4And4To5KernelSelect() override = default; | |||
| void SelectKernel(const CNodePtr &cnode) override { | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsReshapeType({{}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| builder.SetInputsFormat({"NCHW"}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| @@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) { | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}, {}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| sub->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| add->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); | |||
| @@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) { | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}, {}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| sub->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| add->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); | |||
| @@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) { | |||
| builder.SetOutputsFormat({"NC1HWC0"}); | |||
| builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); | |||
| builder.SetOutputsDeviceType({kFloat16->type_id()}); | |||
| builder.SetInputsReshapeType({{}, {}}); | |||
| builder.SetOutputsReshapeType({{}}); | |||
| sub->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| add->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); | |||