Browse Source

split unsupported transdata

tags/v0.6.0-beta
WilliamLian 5 years ago
parent
commit
edba641ddb
12 changed files with 211 additions and 13 deletions
  1. +34
    -2
      mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
  2. +31
    -4
      mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
  4. +2
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
  5. +2
    -2
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  6. +65
    -0
      mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc
  7. +37
    -0
      mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h
  8. +6
    -0
      tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
  9. +6
    -0
      tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
  10. +12
    -0
      tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
  11. +6
    -0
      tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
  12. +8
    -3
      tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc

+ 34
- 2
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc View File

@@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)

std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }

void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
const std::vector<std::vector<Axis>> &input_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->input_reshape_type_ = input_reshape_type;
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
const std::vector<std::vector<Axis>> &output_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->output_reshape_type_ = output_reshape_type;
@@ -189,5 +189,37 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
}
kernel_build_info_->outputs_format_[index] = format;
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
size_t index) {
if (index >= kernel_build_info_->input_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
size_t index) {
if (index >= kernel_build_info_->output_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
if (index >= kernel_build_info_->outputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->outputs_device_type_[index] = output_device_type;
}

void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
if (index >= kernel_build_info_->inputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->inputs_device_type_[index] = input_device_type;
}
} // namespace kernel
} // namespace mindspore

+ 31
- 4
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h View File

@@ -71,6 +71,10 @@ class KernelBuildInfo {

std::vector<TypeId> GetAllOutputDeviceTypes() const;

std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;

std::vector<std::vector<Axis>> GetAllInputReshapeType() const;

OpPattern op_pattern() const { return op_pattern_; }

FusionType fusion_type() const { return fusion_type_; }
@@ -108,8 +112,23 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
public:
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }

explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
: kernel_build_info_(std::move(kernel_build_info)) {}
explicit KernelBuildInfoBuilder(const std::shared_ptr<KernelBuildInfo> &kernel_build_info)
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
SetKernelType(kernel_build_info->kernel_type());
SetFusionType(kernel_build_info->fusion_type());
SetProcessor(kernel_build_info->processor());
OpPattern(kernel_build_info->op_pattern());
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
}
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
}
}

~KernelBuildInfoBuilder() = default;

@@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {

void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);

void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);

void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);

void SetFusionType(FusionType fusion_type);

@@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {

void SetOutputFormat(const std::string &format, size_t index);

void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);

void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);

void SetInputDeviceType(const TypeId &input_device_type, size_t index);

void SetOutputDeviceType(const TypeId &output_device_type, size_t index);

std::shared_ptr<KernelBuildInfo> Build();

private:


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc View File

@@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder.SetInputsDeviceType(inputs_device_type);
builder.SetInputsFormat(inputs_format);
builder.SetInputReshapeType(inputs_reshape_type);
builder.SetInputsReshapeType(inputs_reshape_type);
// output
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type;
@@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputReshapeType(outputs_reshape_type);
builder.SetOutputsReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(INFO) << "end.";


+ 2
- 0
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc View File

@@ -47,6 +47,7 @@
#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
@@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
optimizer->AddPassManager(mixed_precision_pm);


+ 2
- 2
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
MS_EXCEPTION_IF_NULL(ori_build_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
builder->SetInputsFormat({input_format});
builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type});
builder->SetInputsReshapeType({reshape_type});
builder->SetOutputsReshapeType({reshape_type});
builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id});


+ 65
- 0
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc View File

@@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace opt {
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::KPrimTransData, X});
}

const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
auto ori_trans_data = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
return nullptr;
}
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
<< ori_trans_data->DebugString();
}
return SplitTransData(func_graph, ori_trans_data);
}
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
return trans_node;
}
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
std::vector<AnfNodePtr> next_trans_node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
next_trans_node->set_abstract(trans_node->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
return next_trans_node;
}
} // namespace opt
} // namespace mindspore

+ 37
- 0
mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H

#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
class SplitUnsupportedTransData : public PatternProcessPass {
public:
explicit SplitUnsupportedTransData(bool multigraph = true)
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
~SplitUnsupportedTransData() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H

+ 6
- 0
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc View File

@@ -51,6 +51,8 @@ class TestHWInsertTransOp : public BackendCommon {
builder.SetInputsFormat({format, format});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsFormat({format});
builder.SetInputsReshapeType({{},{}});
builder.SetOutputsReshapeType({});
builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
@@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{},{}});
builder.SetOutputsReshapeType({});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({format, format});
@@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});


+ 6
- 0
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc View File

@@ -52,6 +52,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(add, add);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id()});
@@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(tuple_getitem1, max_pool);
kg->AddInternalOutput(tuple_getitem2, max_pool);
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}, {}});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@@ -95,6 +99,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
~MockRemoveInternalOutputTransOpKernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({kOpFormat_NC1HWC0});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({kOpFormat_DEFAULT});


+ 12
- 0
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc View File

@@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@@ -58,6 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@@ -74,10 +78,14 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NCHW"});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
@@ -116,6 +124,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);
@@ -162,6 +172,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);


+ 6
- 0
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc View File

@@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@@ -97,6 +101,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);


+ 8
- 3
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc View File

@@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
~MockEliminate5To4And4To5KernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
@@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});

builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});

builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});

builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());


Loading…
Cancel
Save