|
|
|
@@ -31,6 +31,7 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
constexpr int kOffsetTwo = 2; |
|
|
|
constexpr size_t kCondNodesNum = 12; |
|
|
|
constexpr size_t kCondCNodesNum = 4; |
|
|
|
constexpr size_t kBodyNodesNum = 69; |
|
|
|
@@ -162,7 +163,7 @@ const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const { |
|
|
|
MS_CHECK_TRUE_RET(is_param6 != nullptr, {}); |
|
|
|
auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve, |
|
|
|
fw_init_state_, fw_min, fw_from_tensor, input_length_}); |
|
|
|
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end()); |
|
|
|
fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end()); |
|
|
|
auto is_var1 = std::make_shared<Var>(); |
|
|
|
MS_CHECK_TRUE_RET(is_var1 != nullptr, {}); |
|
|
|
fw_while.emplace_back(is_var1); |
|
|
|
@@ -232,7 +233,7 @@ const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const { |
|
|
|
MS_CHECK_TRUE_RET(is_param6 != nullptr, {}); |
|
|
|
auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve, |
|
|
|
bw_init_state_, bw_min, bw_from_tensor, input_length_}); |
|
|
|
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end()); |
|
|
|
bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end()); |
|
|
|
auto is_var2 = std::make_shared<Var>(); |
|
|
|
MS_CHECK_TRUE_RET(is_var2 != nullptr, {}); |
|
|
|
bw_while.emplace_back(is_var2); |
|
|
|
@@ -400,7 +401,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto fw_cand_kernel_shape = fw_cand_kernel_value->shape(); |
|
|
|
if (fw_cand_kernel_shape.size() != 2) { |
|
|
|
if (fw_cand_kernel_shape.size() != kInputSizeTwo) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf); |
|
|
|
@@ -408,7 +409,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto bw_cand_kernel_shape = bw_cand_kernel_value->shape(); |
|
|
|
if (bw_cand_kernel_shape.size() != 2) { |
|
|
|
if (bw_cand_kernel_shape.size() != kInputSizeTwo) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (fw_cand_kernel_shape != bw_cand_kernel_shape) { |
|
|
|
|