Browse Source

fix GRU fusion for encorder_0111.pb

tags/v1.2.0-rc1
mengyuanli 4 years ago
parent
commit
0d141c1d02
2 changed files with 12 additions and 7 deletions
  1. +5
    -0
      mindspore/lite/src/ops/ops_utils.cc
  2. +7
    -7
      mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc

+ 5
- 0
mindspore/lite/src/ops/ops_utils.cc View File

@@ -236,6 +236,10 @@ schema::PrimitiveT *DropoutGradPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::DropoutGrad>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *GRUPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::GRU>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *EltwisePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Eltwise>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@@ -790,6 +794,7 @@ RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator)
RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator);
RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator);
RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator);
RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator);
RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator);
RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator);
RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator);


+ 7
- 7
mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc View File

@@ -70,15 +70,15 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,

const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
// forward
auto fw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_});
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
std::make_shared<CondVar>(IsParameterNode), fw_max1});
auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
input_length_, std::make_shared<CondVar>(IsParameterNode)});
auto fw_max = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
std::make_shared<CondVar>(IsParameterNode), fw_reduce});

auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_});
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape,
std::make_shared<SeqVar>()});
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max2});
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max});

auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)),
std::make_shared<CondVar>(IsParameterNode), fw_stride});
@@ -100,8 +100,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
// backward
auto bw_reverse_seq =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_});
auto bw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_});
auto bw_max1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_,
std::make_shared<CondVar>(IsParameterNode)});
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
std::make_shared<CondVar>(IsParameterNode), bw_max1});
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq,


Loading…
Cancel
Save