|
|
|
@@ -70,15 +70,15 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, |
|
|
|
|
|
|
|
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { |
|
|
|
// forward |
|
|
|
auto fw_max1 = |
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); |
|
|
|
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)), |
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_max1}); |
|
|
|
auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), |
|
|
|
input_length_, std::make_shared<CondVar>(IsParameterNode)}); |
|
|
|
auto fw_max = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)), |
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_reduce}); |
|
|
|
|
|
|
|
auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_}); |
|
|
|
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape, |
|
|
|
std::make_shared<SeqVar>()}); |
|
|
|
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max2}); |
|
|
|
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max}); |
|
|
|
|
|
|
|
auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)), |
|
|
|
std::make_shared<CondVar>(IsParameterNode), fw_stride}); |
|
|
|
@@ -100,8 +100,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { |
|
|
|
// backward |
|
|
|
auto bw_reverse_seq = |
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_}); |
|
|
|
auto bw_max1 = |
|
|
|
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); |
|
|
|
auto bw_max1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_, |
|
|
|
std::make_shared<CondVar>(IsParameterNode)}); |
|
|
|
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)), |
|
|
|
std::make_shared<CondVar>(IsParameterNode), bw_max1}); |
|
|
|
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq, |
|
|
|
|