From 54c96fe13b613d7eed2ab23b1639de5a1f0c8881 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Mon, 19 Oct 2020 17:19:49 +0800 Subject: [PATCH] Add DynamicRNN for old backend. --- .../ccsrc/transform/graph_ir/op_adapter_map.h | 2 + .../graph_ir/op_declare/rnn_declare.cc | 44 +++++++++++++++++++ .../graph_ir/op_declare/rnn_declare.h | 8 +++- mindspore/ops/operations/_grad_ops.py | 2 +- 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h index a98a7a0005..2cf6de4dfd 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h @@ -183,6 +183,8 @@ constexpr const char kNameBasicLSTMCell[] = "BasicLSTMCell"; constexpr const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad"; constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; +constexpr const char kNameDynamicRNN[] = "DynamicRNN"; +constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad"; constexpr const char kNameL2Loss[] = "L2Loss"; constexpr const char kNameCTCLoss[] = "CTCLoss"; constexpr const char kNameRange[] = "Range"; diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc index 1b6c433664..484a76a54a 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc @@ -48,4 +48,48 @@ ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}, {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}}; OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}}; REG_ADPT_DESC(BasicLSTMCellCStateGrad, kNameBasicLSTMCellCStateGrad, ADPT_DESC(BasicLSTMCellCStateGrad)) + +// DynamicRNN +INPUT_MAP(DynamicRNN) = {{1, 
INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, + {4, INPUT_DESC(seq_length)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)}, + {7, INPUT_DESC(wci)}, {8, INPUT_DESC(wcf)}, {9, INPUT_DESC(wco)}, + {10, INPUT_DESC(mask)}}; +ATTR_MAP(DynamicRNN) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())}, + {"direction", ATTR_DESC(direction, AnyTraits<std::string>())}, + {"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())}, + {"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())}, + {"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}, + {"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())}, + {"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())}, + {"time_major", ATTR_DESC(time_major, AnyTraits<bool>())}, + {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}, + {"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}, + {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}}; +OUTPUT_MAP(DynamicRNN) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(output_c)}, + {3, OUTPUT_DESC(i)}, {4, OUTPUT_DESC(j)}, {5, OUTPUT_DESC(f)}, + {6, OUTPUT_DESC(o)}, {7, OUTPUT_DESC(tanhc)}}; +REG_ADPT_DESC(DynamicRNN, kNameDynamicRNN, ADPT_DESC(DynamicRNN)) + +// DynamicRNNGrad +INPUT_MAP(DynamicRNNGrad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(y)}, + {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(init_c)}, {7, INPUT_DESC(h)}, {8, INPUT_DESC(c)}, + {9, INPUT_DESC(dy)}, {10, INPUT_DESC(dh)}, {11, INPUT_DESC(dc)}, {12, INPUT_DESC(i)}, + {13, INPUT_DESC(j)}, {14, INPUT_DESC(f)}, {15, INPUT_DESC(o)}, {16, INPUT_DESC(tanhct)}}; + +ATTR_MAP(DynamicRNNGrad) = {{"cell_type", ATTR_DESC(cell_type, AnyTraits<std::string>())}, + {"direction", ATTR_DESC(direction, AnyTraits<std::string>())}, + {"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())}, + {"use_peephole", ATTR_DESC(use_peephole, AnyTraits<bool>())}, + {"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}, + {"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())}, + {"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())}, + {"time_major", ATTR_DESC(time_major, AnyTraits<bool>())}, + 
{"forget_bias", ATTR_DESC(forget_bias, AnyTraits())}}; +OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)}, + {1, OUTPUT_DESC(db)}, + {2, OUTPUT_DESC(dx)}, + {3, OUTPUT_DESC(dh_prev)}, + {4, OUTPUT_DESC(dc_prev)}}; +REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h index 247c6a1a14..0939fdb131 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.h @@ -19,8 +19,8 @@ #include #include -#include "transform/graph_ir/op_declare/op_declare_macro.h" #include "ops/rnn.h" +#include "transform/graph_ir/op_declare/op_declare_macro.h" namespace mindspore::transform { DECLARE_OP_ADAPTER(BasicLSTMCell) @@ -34,5 +34,11 @@ DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad) DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad) DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad) + +DECLARE_OP_ADAPTER(DynamicRNN) +DECLARE_OP_USE_OUTPUT(DynamicRNN) + +DECLARE_OP_ADAPTER(DynamicRNNGrad) +DECLARE_OP_USE_OUTPUT(DynamicRNNGrad) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_ diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 48f405678d..39ca6581ba 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1062,7 +1062,7 @@ class DynamicRNNGrad(PrimitiveWithInfer): keep_prob=-1.0, cell_clip=-1.0, num_proj=0, - time_major=False, + time_major=True, forget_bias=0.0): self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) self.add_prim_attr("io_format", "ND")