| @@ -187,6 +187,8 @@ constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; | |||
| constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; | |||
| constexpr const char kNameDynamicRNN[] = "DynamicRNN"; | |||
| constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad"; | |||
| constexpr const char kNameDynamicGRUV2[] = "DynamicGRUV2"; | |||
| constexpr const char kNameDynamicGRUV2Grad[] = "DynamicGRUV2Grad"; | |||
| constexpr const char kNameL2Loss[] = "L2Loss"; | |||
| constexpr const char kNameCTCLoss[] = "CTCLoss"; | |||
| constexpr const char kNameRange[] = "Range"; | |||
| @@ -92,4 +92,42 @@ OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)}, | |||
| {3, OUTPUT_DESC(dh_prev)}, | |||
| {4, OUTPUT_DESC(dc_prev)}}; | |||
| REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad)) | |||
| // DynamicGRUV2 | |||
| INPUT_MAP(DynamicGRUV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)}, | |||
| {4, INPUT_DESC(bias_input)}, {5, INPUT_DESC(bias_hidden)}, {6, INPUT_DESC(seq_length)}, | |||
| {7, INPUT_DESC(init_h)}}; | |||
| ATTR_MAP(DynamicGRUV2) = {{"direction", ATTR_DESC(direction, AnyTraits<std::string>())}, | |||
| {"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())}, | |||
| {"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}, | |||
| {"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())}, | |||
| {"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())}, | |||
| {"time_major", ATTR_DESC(time_major, AnyTraits<bool>())}, | |||
| {"activation", ATTR_DESC(direction, AnyTraits<std::string>())}, | |||
| {"gate_order", ATTR_DESC(gate_order, AnyTraits<std::string>())}, | |||
| {"reset_after", ATTR_DESC(reset_after, AnyTraits<bool>())}, | |||
| {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(DynamicGRUV2) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(update)}, | |||
| {3, OUTPUT_DESC(reset)}, {4, OUTPUT_DESC(new)}, {5, OUTPUT_DESC(hidden_new)}}; | |||
| REG_ADPT_DESC(DynamicGRUV2, kNameDynamicGRUV2, ADPT_DESC(DynamicGRUV2)) | |||
| // DynamicGRUV2Grad | |||
| INPUT_MAP(DynamicGRUV2Grad) = { | |||
| {1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)}, | |||
| {4, INPUT_DESC(y)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(h)}, | |||
| {7, INPUT_DESC(dy)}, {8, INPUT_DESC(dh)}, {9, INPUT_DESC(update)}, | |||
| {10, INPUT_DESC(reset)}, {11, INPUT_DESC(new)}, {12, INPUT_DESC(hidden_new)}, | |||
| {13, INPUT_DESC(seq_length)}, {14, INPUT_DESC(mask)}}; | |||
| ATTR_MAP(DynamicGRUV2Grad) = {{"direction", ATTR_DESC(direction, AnyTraits<std::string>())}, | |||
| {"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())}, | |||
| {"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}, | |||
| {"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())}, | |||
| {"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())}, | |||
| {"time_major", ATTR_DESC(time_major, AnyTraits<bool>())}, | |||
| {"bias_type", ATTR_DESC(bias_type, AnyTraits<std::string>())}, | |||
| {"gate_order", ATTR_DESC(gate_order, AnyTraits<std::string>())}, | |||
| {"reset_after", ATTR_DESC(reset_after, AnyTraits<bool>())}}; | |||
| OUTPUT_MAP(DynamicGRUV2Grad) = {{0, OUTPUT_DESC(dw_input)}, {1, OUTPUT_DESC(dw_hidden)}, {2, OUTPUT_DESC(db_input)}, | |||
| {3, OUTPUT_DESC(db_hidden)}, {4, OUTPUT_DESC(dx)}, {5, OUTPUT_DESC(dh_prev)}}; | |||
| REG_ADPT_DESC(DynamicGRUV2Grad, kNameDynamicGRUV2Grad, ADPT_DESC(DynamicGRUV2Grad)) | |||
| } // namespace mindspore::transform | |||
| @@ -40,5 +40,11 @@ DECLARE_OP_USE_OUTPUT(DynamicRNN) | |||
| DECLARE_OP_ADAPTER(DynamicRNNGrad) | |||
| DECLARE_OP_USE_OUTPUT(DynamicRNNGrad) | |||
| DECLARE_OP_ADAPTER(DynamicGRUV2) | |||
| DECLARE_OP_USE_OUTPUT(DynamicGRUV2) | |||
| DECLARE_OP_ADAPTER(DynamicGRUV2Grad) | |||
| DECLARE_OP_USE_OUTPUT(DynamicGRUV2Grad) | |||
| } // namespace mindspore::transform | |||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_ | |||
| @@ -225,6 +225,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad"; | |||
| constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell"; | |||
| constexpr auto kDynamicRNNOpName = "DynamicRNN"; | |||
| constexpr auto kLSTMInputGradOpName = "LSTMInputGrad"; | |||
| constexpr auto kDynamicGRUOpName = "DynamicGRU"; | |||
| constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad"; | |||
| constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl"; | |||
| constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad"; | |||
| constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam"; | |||
| @@ -110,6 +110,8 @@ inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueG | |||
| inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches"); | |||
| inline const PrimitivePtr kPrimDynamicRNN = std::make_shared<Primitive>("DynamicRNN"); | |||
| inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared<Primitive>("DynamicRNNGrad"); | |||
| inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("DynamicGRUV2"); | |||
| inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad"); | |||
| inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd"); | |||
| inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate"); | |||
| inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div"); | |||
| @@ -849,7 +849,16 @@ def get_bprop_lstm(self): | |||
| @bprop_getters.register(P.DynamicRNN) | |||
| def get_bprop_dynamic_rnn(self): | |||
| """Grad definition for `DynamicRNN` operation.""" | |||
| dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias) | |||
| dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type, | |||
| direction=self.direction, | |||
| cell_depth=self.cell_depth, | |||
| use_peephole=self.use_peephole, | |||
| keep_prob=self.keep_prob, | |||
| cell_clip=self.cell_clip, | |||
| num_proj=self.num_proj, | |||
| time_major=self.time_major, | |||
| forget_bias=self.forget_bias) | |||
| expand_dims = P.ExpandDims() | |||
| def bprop(x, w, b, seq_length, init_h, init_c, out, dout): | |||
| dy, dh, dc, _, _, _, _, _, = dout | |||
| @@ -858,10 +867,30 @@ def get_bprop_dynamic_rnn(self): | |||
| y, h, c, i, j, f, o, tanhct = out | |||
| dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h, | |||
| c, dy, dh, dc, i, j, f, o, tanhct) | |||
| dh_prev = expand_dims(dh_prev, 0) | |||
| dc_prev = expand_dims(dc_prev, 0) | |||
| return dx, dw, db, (0), dh_prev, dc_prev | |||
| return bprop | |||
| @bprop_getters.register(inner.DynamicGRUV2) | |||
| def get_bprop_dynamic_gru_v2(self): | |||
| """Grad definition for `DynamicGRUV2` operation.""" | |||
| dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip, | |||
| self.num_proj, self.time_major, 'double_bias', self.gate_order, | |||
| self.reset_after) | |||
| def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout): | |||
| y, out_h, update, reset, new, hidden_new = out | |||
| dy, dout_h, _, _, _, _ = dout | |||
| dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h, | |||
| out_h, dy, dout_h[-1], update, | |||
| reset, new, hidden_new, None, None) | |||
| return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev | |||
| return bprop | |||
| @bprop_getters.register(P.SigmoidCrossEntropyWithLogits) | |||
| def get_bprop_sigmoid_crossentropy_with_logits(self): | |||
| """Grad definition for `SigmoidCrossEntropyWithLogits` operation.""" | |||
| @@ -286,6 +286,8 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe | |||
| from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe | |||
| from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe | |||
| from .dynamic_rnn import _dynamic_rnn_tbe | |||
| from .dynamic_gru_v2 import _dynamic_gru_v2_tbe | |||
| from .gru_v2_hidden_grad import _gru_v2_hidden_grad_tbe | |||
| from .lstm_input_grad import _lstm_input_grad_tbe | |||
| from .confusion_matrix import _confusion_matrix_tbe | |||
| from .broadcast_to import _broadcast_to_tbe | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DynamicGRUV2 op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("dynamic_gru_v2.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("dynamic_gru_v2") \ | |||
| .attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \ | |||
| .attr("cell_depth", "optional", "int", "all", "1") \ | |||
| .attr("keep_prob", "optional", "float", "all", "1") \ | |||
| .attr("cell_clip", "optional", "float", "all", "-1") \ | |||
| .attr("num_proj", "optional", "int", "all", "0") \ | |||
| .attr("time_major", "optional", "bool", "all", "true") \ | |||
| .attr("activation", "optional", "str", "all", "tanh") \ | |||
| .attr("gate_order", "optional", "str", "all", "rzh") \ | |||
| .attr("reset_after", "optional", "bool", "all", "true") \ | |||
| .attr("is_training", "optional", "bool", "all", "true") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .input(1, "weight_input", False, "required", "all") \ | |||
| .input(2, "weight_hidden", False, "required", "all") \ | |||
| .input(3, "bias_input", False, "optional", "all") \ | |||
| .input(4, "bias_hidden", False, "optional", "all") \ | |||
| .input(5, "seq_length", False, "optional", "all") \ | |||
| .input(6, "init_h", False, "optional", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .output(1, "output_h", False, "required", "all") \ | |||
| .output(2, "update", False, "optional", "all") \ | |||
| .output(3, "reset", False, "optional", "all") \ | |||
| .output(4, "new", False, "optional", "all") \ | |||
| .output(5, "hidden_new", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.I32_Default, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register(dynamic_gru_v2_op_info) | |||
| def _dynamic_gru_v2_tbe(): | |||
| """DynamicGRUV2 TBE register""" | |||
| return | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """GRUV2HiddenGrad op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("gru_v2_hidden_grad.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("gru_v2_hidden_grad") \ | |||
| .attr("gate_order", "optional", "str", "all", "zrh") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "weight_input", False, "required", "all") \ | |||
| .input(1, "init_h", False, "required", "all") \ | |||
| .input(2, "h", False, "required", "all") \ | |||
| .input(3, "dy", False, "optional", "all") \ | |||
| .input(4, "dh", False, "optional", "all") \ | |||
| .input(5, "update", False, "optional", "all") \ | |||
| .input(6, "reset", False, "optional", "all") \ | |||
| .input(7, "new", False, "optional", "all") \ | |||
| .input(8, "hidden_new", False, "optional", "all") \ | |||
| .output(0, "dh_preh", False, "required", "all") \ | |||
| .output(1, "dgate_h", False, "required", "all") \ | |||
| .output(2, "dnt_x", False, "optional", "all") \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, | |||
| DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, | |||
| DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .get_op_info() | |||
| @op_info_register(gru_v2_hidden_grad_op_info) | |||
| def _gru_v2_hidden_grad_tbe(): | |||
| """DynamicGRUV2 TBE register""" | |||
| return | |||
| @@ -1095,9 +1095,9 @@ class DynamicRNNGrad(PrimitiveWithInfer): | |||
| def __init__(self, | |||
| cell_type='LSTM', | |||
| direction='UNIDIRECTIONAL', | |||
| cell_depth=0, | |||
| cell_depth=1, | |||
| use_peephole=False, | |||
| keep_prob=-1.0, | |||
| keep_prob=1.0, | |||
| cell_clip=-1.0, | |||
| num_proj=0, | |||
| time_major=True, | |||
| @@ -1135,6 +1135,147 @@ class DynamicRNNGrad(PrimitiveWithInfer): | |||
| return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype | |||
| class DynamicGRUV2Grad(PrimitiveWithInfer): | |||
| r""" | |||
| Computes the input gradients of DynamicGRUV2. | |||
| Args: | |||
| direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| Only 'UNIDIRECTIONAL' is currently supported. | |||
| cell_depth (int): An integer identifying the cell depth in the op. Default: 1. | |||
| keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. | |||
| num_proj (int): An integer identifying the num proj in the op. Default: 0. | |||
| time_major (bool): A bool identifying the time major in the op. Default: True. | |||
| bias_type (str): An string identifying the type of bias_type function in the op. Default to "double_bias". | |||
| gate_order (str): An string identifying the gate order in weight and bias. Default: 'rzh. | |||
| 'zrh' is another option. | |||
| reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`. | |||
| The data type must be float16 or float32. | |||
| - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **weight_hidden** (Tensor) - Bias. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **y** (Tensor) - A Tensor of shape :math: | |||
| if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, | |||
| if num_proj == 0 `(num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **init_h** (Tensor) - Hidden state of initial time. | |||
| Tensor of shape :math:`(batch_size, hidden_size)`, or None. | |||
| The data type must be float16 or float32. | |||
| - **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`. | |||
| - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`. | |||
| - **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. | |||
| The data type must be float16 or float32. | |||
| - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`. | |||
| Only `None` is currently supported. | |||
| - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32. | |||
| Outputs: | |||
| - **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`. | |||
| Has the same type with input `x`. | |||
| - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`. | |||
| Has the same type with input `x`. | |||
| - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. | |||
| Has the same type with input `x`. | |||
| - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. | |||
| Has the same type with input `x`. | |||
| - **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. | |||
| Has the same type with input `x`. | |||
| - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`. | |||
| Has the same type with input `x`. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| direction='UNIDIRECTIONAL', | |||
| cell_depth=1, | |||
| keep_prob=1.0, | |||
| cell_clip=-1.0, | |||
| num_proj=0, | |||
| time_major=True, | |||
| bias_type="double_bias", | |||
| gate_order="zrh", | |||
| reset_after=True): | |||
| self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) | |||
| self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) | |||
| self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.bias_type = validator.check_string(bias_type, | |||
| ['no_bias', 'single_bias', 'double_bias'], "bias_type", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape, | |||
| dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) | |||
| validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) | |||
| validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) | |||
| validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name) | |||
| num_step, batch_size, input_size = x_shape | |||
| hidden_size = whidden_shape[0] | |||
| validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size", | |||
| 3 * hidden_size, Rel.EQ, self.name) | |||
| validator.check("weight_input_shape", winput_shape, "excepted shape", | |||
| [input_size, 3 * hidden_size], Rel.EQ, self.name) | |||
| if self.num_proj > 0: | |||
| valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)] | |||
| else: | |||
| valid_y_shape = [num_step, batch_size, hidden_size] | |||
| validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name) | |||
| validator.check("init_h_shape", init_h_shape, "excepted shape", | |||
| [batch_size, hidden_size], Rel.EQ, self.name) | |||
| valid_shape = [num_step, batch_size, hidden_size] | |||
| validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| validator.check("dh_shape", dh_shape, "excepted shape", | |||
| [batch_size, hidden_size], Rel.EQ, self.name) | |||
| validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name) | |||
| if seq_shape is not None: | |||
| validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name) | |||
| dx_shape = (num_step, batch_size, input_size) | |||
| dh_shape = (batch_size, hidden_size) | |||
| dwinput_shape = (input_size, 3 * hidden_size) | |||
| dwhidden_shape = (hidden_size, 3 * hidden_size) | |||
| db_shape = (3 * hidden_size,) | |||
| return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype, | |||
| dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, | |||
| "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, | |||
| "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} | |||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| if seq_dtype is not None: | |||
| validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| if mask_dtype is not None: | |||
| validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype | |||
| class PReLUGrad(PrimitiveWithInfer): | |||
| r""" | |||
| Gradients of PReLU operation. | |||
| @@ -451,6 +451,157 @@ class MatrixSetDiag(PrimitiveWithInfer): | |||
| return assist_shape | |||
| class DynamicGRUV2(PrimitiveWithInfer): | |||
| r""" | |||
| DynamicGRUV2 Operator. | |||
| Args: | |||
| direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| Only 'UNIDIRECTIONAL' is currently supported. | |||
| cell_depth (int): An integer identifying the cell depth in the op. Default: 1. | |||
| keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. | |||
| num_proj (int): An integer identifying the num proj in the op. Default: 0. | |||
| time_major (bool): A bool identifying the time major in the op. Default: True. | |||
| activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'. | |||
| Only 'tanh' is currently supported. | |||
| gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh. | |||
| 'zrh' is another option. | |||
| reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. | |||
| is_training (bool): A bool identifying is training in the op. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. | |||
| Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. | |||
| The data type must be float16. | |||
| - **weight_input** (Tensor) - Input-hidden weight. | |||
| Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **weight_hidden** (Tensor) - Hidden-hidden weight. | |||
| Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. | |||
| Only `None` is currently supported. | |||
| - **init_h** (Tensor) - Hidden state of initial time. | |||
| Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| Outputs: | |||
| - **y** (Tensor) - A Tensor of shape :math: | |||
| if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, | |||
| if num_proj == 0 `(num_step, batch_size, hidden_size)`. | |||
| Has the same data type with input `bais_type`. | |||
| - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32. | |||
| - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`. | |||
| - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`. | |||
| - Otherwise, `bias_type` is the date type of `init_h`. | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) | |||
| >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) | |||
| >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) | |||
| >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) | |||
| >>> dynamic_gru_v2 = P.DynamicGRUV2() | |||
| >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) | |||
| >>> output[0].shape | |||
| (2, 8, 16) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| direction='UNIDIRECTIONAL', | |||
| cell_depth=1, | |||
| keep_prob=1.0, | |||
| cell_clip=-1.0, | |||
| num_proj=0, | |||
| time_major=True, | |||
| activation="tanh", | |||
| gate_order="rzh", | |||
| reset_after=True, | |||
| is_training=True): | |||
| self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) | |||
| self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) | |||
| self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) | |||
| self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) | |||
| validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) | |||
| validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) | |||
| if binput_shape is not None: | |||
| validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) | |||
| if bhidden_shape is not None: | |||
| validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) | |||
| if h_shape is not None: | |||
| validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) | |||
| if seq_shape is not None: | |||
| raise ValueError(f"For {self.name}, seq_shape should be None.") | |||
| num_step, batch_size, input_size = x_shape | |||
| hidden_size = winput_shape[-1] // 3 | |||
| if winput_shape[-1] % 3 != 0: | |||
| raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.") | |||
| validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", | |||
| whidden_shape[-1], Rel.EQ, self.name) | |||
| validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) | |||
| validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if h_shape is not None: | |||
| validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) | |||
| validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if self.num_proj > 0: | |||
| y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) | |||
| else: | |||
| y_shape = (num_step, batch_size, hidden_size) | |||
| outh_shape = (num_step, batch_size, hidden_size) | |||
| return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | |||
| validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) | |||
| b_dtype = mstype.float32 | |||
| if binput_dtype is not None: | |||
| validator.check_tensor_type_same({"bias input dtype": binput_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = binput_dtype | |||
| elif bhidden_dtype is not None: | |||
| validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = bhidden_dtype | |||
| elif h_dtype is not None: | |||
| validator.check_tensor_type_same({"init_h dtype": h_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = h_dtype | |||
| return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| class ConfusionMulGrad(PrimitiveWithInfer): | |||
| """ | |||
| `output0` is the dot product result of input0 and input1. | |||
| @@ -5611,33 +5611,35 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| DynamicRNN Operator. | |||
| Args: | |||
| cell_type (str): An string identifying the cell type in the op. Default: 'LSTM'. | |||
| cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'. | |||
| Only 'LSTM' is currently supported. | |||
| direction (str): An string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| Only 'UNIDIRECTIONAL' is currently supported. | |||
| cell_depth (int): An integer identifying the cell depth in the op. Default: 1. | |||
| use_peephole (bool): An bool identifying if use peephole in the op. Default: False. | |||
| keep_prob (float): An float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): An float identifying the cell clip in the op. Default: -1.0. | |||
| use_peephole (bool): A bool identifying if use peephole in the op. Default: False. | |||
| keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. | |||
| num_proj (int): An integer identifying the num proj in the op. Default: 0. | |||
| time_major (bool): An bool identifying the time major in the op. Default: True. | |||
| time_major (bool): A bool identifying the time major in the op. Default: True. | |||
| Only `True` is currently supported. | |||
| activation (str): An string identifying the type of activation function in the op. Default: 'tanh'. | |||
| activation (str): A string identifying the type of activation function in the op. Default: 'tanh'. | |||
| Only 'tanh' is currently supported. | |||
| forget_bias (float): An float identifying the forget bias in the op. Default: 0.0. | |||
| is_training (bool): An bool identifying is training in the op. Default: True. | |||
| forget_bias (float): A float identifying the forget bias in the op. Default: 0.0. | |||
| is_training (bool): A bool identifying is training in the op. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. Tensor of shape (`num_step`, `batch_size`, `input_size`). | |||
| The data type must be float16 or float32. | |||
| The data type must be float16. | |||
| - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`). | |||
| The data type must be float16 or float32. | |||
| The data type must be float16. | |||
| - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`). | |||
| The data type must be float16 or float32. | |||
| - **seq_length** (Tensor) - The length of each batch. Tensor of shape (`batch_size`). | |||
| Only `None` is currently supported. | |||
| - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`). | |||
| The data type must be float16. | |||
| - **init_c** (Tensor) - Cell state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`). | |||
| The data type must be float16. | |||
| Outputs: | |||
| - **y** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`). | |||
| @@ -5664,7 +5666,9 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> dynamic_rnn = P.DynamicRNN() | |||
| >>> output = lstm(x, w, b, None, init_h, init_c) | |||
| >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) | |||
| >>> output[0].shape | |||
| (2, 16, 32) | |||
| """ | |||
| @prim_attr_register | |||
| @@ -5684,7 +5688,7 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) | |||
| self.num_proj = validator.check_value_type("num_proj", num_proj, [int], self.name) | |||
| self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name) | |||
| self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) | |||
| @@ -5721,11 +5725,11 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape | |||
| def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): | |||
| validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name) | |||
| return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| @@ -817,6 +817,17 @@ class BasicLSTMCellNet(nn.Cell): | |||
| return self.lstm(x, h, c, w, b) | |||
| class DynamicGRUV2Net(nn.Cell): | |||
| """ DynamicGRUV2Net definition """ | |||
| def __init__(self): | |||
| super(DynamicGRUV2Net, self).__init__() | |||
| self.dynamic_gru = inner.DynamicGRUV2() | |||
| def construct(self, x, w_i, w_h, b_i, b_h, init_h): | |||
| return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h) | |||
| class EditDistance(nn.Cell): | |||
| def __init__(self, hypothesis_shape, truth_shape, normalize=True): | |||
| super(EditDistance, self).__init__() | |||
| @@ -2508,6 +2519,19 @@ test_case_other_ops = [ | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16))]}), | |||
| ('DynamicGRUV2Net', { | |||
| 'block': DynamicGRUV2Net(), | |||
| 'desc_inputs': [Tensor(np.random.rand(2, 8, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(64, 48).astype(np.float16)), | |||
| Tensor(np.random.rand(16, 48).astype(np.float16)), | |||
| Tensor(np.random.rand(48).astype(np.float16)), | |||
| Tensor(np.random.rand(48).astype(np.float16)), | |||
| Tensor(np.random.rand(8, 16).astype(np.float16))], | |||
| 'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)), | |||
| Tensor(np.random.rand(2, 8, 16).astype(np.float16)), | |||
| Tensor(np.random.rand(2, 8, 16).astype(np.float16)), | |||
| Tensor(np.random.rand(2, 8, 16).astype(np.float16)), | |||
| Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}), | |||
| ] | |||
| test_case_quant_ops = [ | |||