From: @liu_xiao_93 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -892,7 +892,7 @@ def get_bprop_dynamic_rnn(self): | |||
| return bprop | |||
| @bprop_getters.register(inner.DynamicGRUV2) | |||
| @bprop_getters.register(P.DynamicGRUV2) | |||
| def get_bprop_dynamic_gru_v2(self): | |||
| """Grad definition for `DynamicGRUV2` operation.""" | |||
| dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip, | |||
| @@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, | |||
| SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2, | |||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, | |||
| @@ -119,6 +119,7 @@ __all__ = [ | |||
| 'Rsqrt', | |||
| 'Sqrt', | |||
| 'Square', | |||
| 'DynamicGRUV2', | |||
| 'SquaredDifference', | |||
| 'Xdivy', | |||
| 'Xlogy', | |||
| @@ -529,159 +529,6 @@ class MatrixSetDiag(PrimitiveWithInfer): | |||
| return assist_shape | |||
| class DynamicGRUV2(PrimitiveWithInfer): | |||
| r""" | |||
| DynamicGRUV2 Operator. | |||
| Args: | |||
| direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| Only 'UNIDIRECTIONAL' is currently supported. | |||
| cell_depth (int): An integer identifying the cell depth in the op. Default: 1. | |||
| keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. | |||
| num_proj (int): An integer identifying the num proj in the op. Default: 0. | |||
| time_major (bool): A bool identifying the time major in the op. Default: True. | |||
| activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'. | |||
| Only 'tanh' is currently supported. | |||
| gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh. | |||
| 'zrh' is another option. | |||
| reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. | |||
| is_training (bool): A bool identifying is training in the op. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. | |||
| Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. | |||
| The data type must be float16. | |||
| - **weight_input** (Tensor) - Input-hidden weight. | |||
| Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **weight_hidden** (Tensor) - Hidden-hidden weight. | |||
| Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. | |||
| Only `None` is currently supported. | |||
| - **init_h** (Tensor) - Hidden state of initial time. | |||
| Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`. | |||
| The data type must be float16 or float32. | |||
| Outputs: | |||
| - **y** (Tensor) - A Tensor of shape :math: | |||
| if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, | |||
| if num_proj == 0 `(num_step, batch_size, hidden_size)`. | |||
| Has the same data type with input `bais_type`. | |||
| - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32. | |||
| - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`. | |||
| - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`. | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) | |||
| >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) | |||
| >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) | |||
| >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) | |||
| >>> dynamic_gru_v2 = ops.DynamicGRUV2() | |||
| >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) | |||
| >>> result = output[0].shape | |||
| >>> print(result) | |||
| (2, 8, 16) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| direction='UNIDIRECTIONAL', | |||
| cell_depth=1, | |||
| keep_prob=1.0, | |||
| cell_clip=-1.0, | |||
| num_proj=0, | |||
| time_major=True, | |||
| activation="tanh", | |||
| gate_order="rzh", | |||
| reset_after=True, | |||
| is_training=True): | |||
| self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) | |||
| self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) | |||
| self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) | |||
| self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) | |||
| validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) | |||
| validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) | |||
| num_step, batch_size, input_size = x_shape | |||
| hidden_size = winput_shape[-1] // 3 | |||
| if winput_shape[-1] % 3 != 0: | |||
| raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.") | |||
| self.placeholder_index = [3, 4, 5] | |||
| if binput_shape is not None: | |||
| validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) | |||
| validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) | |||
| self.placeholder_index.remove(3) | |||
| if bhidden_shape is not None: | |||
| validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) | |||
| validator.check("bias_hidden_shape", bhidden_shape, | |||
| "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) | |||
| self.placeholder_index.remove(4) | |||
| if seq_shape is not None: | |||
| raise ValueError(f"For {self.name}, seq_shape should be None.") | |||
| validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) | |||
| validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) | |||
| validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", | |||
| whidden_shape[-1], Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) | |||
| validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if self.num_proj > 0: | |||
| y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) | |||
| else: | |||
| y_shape = (num_step, batch_size, hidden_size) | |||
| out_shape = (num_step, batch_size, hidden_size) | |||
| self.add_prim_attr("placeholder_index", self.placeholder_index) | |||
| return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | |||
| validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = mstype.float32 | |||
| if binput_dtype is not None: | |||
| validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = binput_dtype | |||
| elif bhidden_dtype is not None: | |||
| validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = bhidden_dtype | |||
| return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| class ConfusionMulGrad(PrimitiveWithInfer): | |||
| """ | |||
| `output0` is the dot product result of input0 and input1. | |||
| @@ -6403,32 +6403,18 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| - **tanhct** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`). | |||
| Has the same type with input `b`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> import mindspore | |||
| >>> import mindspore.nn as nn | |||
| >>> import numpy as np | |||
| >>> from mindspore import Parameter | |||
| >>> from mindspore import Tensor | |||
| >>> from mindspore.ops import operations as ops | |||
| >>> import mindspore.context as context | |||
| >>> context.set_context(mode=context.GRAPH_MODE) | |||
| >>> class DynamicRNNNet(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(DynamicRNNNet, self).__init__() | |||
| >>> self.dynamic_rnn = ops.DynamicRNN() | |||
| >>> | |||
| >>> def construct(self, x, w, b, init_h, init_c): | |||
| >>> out = self.dynamic_rnn(x, w, b, None, init_h, init_c) | |||
| >>> return out | |||
| >>> | |||
| >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16)) | |||
| >>> w = Tensor(np.random.rand(96, 128).astype(np.float16)) | |||
| >>> b = Tensor(np.random.rand(128).astype(np.float16)) | |||
| >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | |||
| >>> net = DynamicRNNNet() | |||
| >>> output = net(x, w, b, init_h, init_c) | |||
| >>> output[0].shape | |||
| >>> dynamic_rnn = ops.DynamicRNNN() | |||
| >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) | |||
| >>> print(output[0].shape) | |||
| (2, 16, 32) | |||
| """ | |||
| @@ -6493,6 +6479,161 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| class DynamicGRUV2(PrimitiveWithInfer): | |||
| r""" | |||
| DynamicGRUV2 Operator. | |||
| Args: | |||
| direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. | |||
| Only 'UNIDIRECTIONAL' is currently supported. | |||
| cell_depth (int): An integer identifying the cell depth in the op. Default: 1. | |||
| keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. | |||
| cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. | |||
| num_proj (int): An integer identifying the num proj in the op. Default: 0. | |||
| time_major (bool): A bool identifying the time major in the op. Default: True. | |||
| activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'. | |||
| Only 'tanh' is currently supported. | |||
| gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh. | |||
| 'zrh' is another option. | |||
| reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True. | |||
| is_training (bool): A bool identifying is training in the op. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. | |||
| Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`. | |||
| The data type must be float16. | |||
| - **weight_input** (Tensor) - Input-hidden weight. | |||
| Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **weight_hidden** (Tensor) - Hidden-hidden weight. | |||
| Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. | |||
| The data type must be float16. | |||
| - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. | |||
| The data type must be float16 or float32. | |||
| - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. | |||
| Only `None` is currently supported. | |||
| - **init_h** (Tensor) - Hidden state of initial time. | |||
| Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`. | |||
| The data type must be float16 or float32. | |||
| Outputs: | |||
| - **y** (Tensor) - A Tensor of shape :math: | |||
| if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, | |||
| if num_proj == 0 `(num_step, batch_size, hidden_size)`. | |||
| Has the same data type with input `bais_type`. | |||
| - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`. | |||
| Has the same data type with input `bais_type`. | |||
| - If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32. | |||
| - If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`. | |||
| - If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16)) | |||
| >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16)) | |||
| >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16)) | |||
| >>> bias_i = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> bias_h = Tensor(np.random.rand(48).astype(np.float16)) | |||
| >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16)) | |||
| >>> dynamic_gru_v2 = ops.DynamicGRUV2() | |||
| >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h) | |||
| >>> print(output[0].shape) | |||
| (2, 8, 16) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| direction='UNIDIRECTIONAL', | |||
| cell_depth=1, | |||
| keep_prob=1.0, | |||
| cell_clip=-1.0, | |||
| num_proj=0, | |||
| time_major=True, | |||
| activation="tanh", | |||
| gate_order="rzh", | |||
| reset_after=True, | |||
| is_training=True): | |||
| self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) | |||
| self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) | |||
| self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) | |||
| self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) | |||
| validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) | |||
| validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) | |||
| num_step, batch_size, input_size = x_shape | |||
| hidden_size = winput_shape[-1] // 3 | |||
| if winput_shape[-1] % 3 != 0: | |||
| raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.") | |||
| self.placeholder_index = [3, 4, 5] | |||
| if binput_shape is not None: | |||
| validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name) | |||
| validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) | |||
| self.placeholder_index.remove(3) | |||
| if bhidden_shape is not None: | |||
| validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name) | |||
| validator.check("bias_hidden_shape", bhidden_shape, | |||
| "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name) | |||
| self.placeholder_index.remove(4) | |||
| if seq_shape is not None: | |||
| raise ValueError(f"For {self.name}, seq_shape should be None.") | |||
| validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name) | |||
| validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name) | |||
| validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]", | |||
| whidden_shape[-1], Rel.EQ, self.name) | |||
| validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name) | |||
| validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name) | |||
| if self.num_proj > 0: | |||
| y_shape = (num_step, batch_size, min(hidden_size, self.num_proj)) | |||
| else: | |||
| y_shape = (num_step, batch_size, hidden_size) | |||
| out_shape = (num_step, batch_size, hidden_size) | |||
| self.add_prim_attr("placeholder_index", self.placeholder_index) | |||
| return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | |||
| validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = mstype.float32 | |||
| if binput_dtype is not None: | |||
| validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = binput_dtype | |||
| elif bhidden_dtype is not None: | |||
| validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = bhidden_dtype | |||
| return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| class InTopK(PrimitiveWithInfer): | |||
| r""" | |||
| Determines whether the targets are in the top `k` predictions. | |||
| @@ -822,7 +822,7 @@ class DynamicGRUV2Net(nn.Cell): | |||
| def __init__(self): | |||
| super(DynamicGRUV2Net, self).__init__() | |||
| self.dynamic_gru = inner.DynamicGRUV2() | |||
| self.dynamic_gru = P.DynamicGRUV2() | |||
| def construct(self, x, w_i, w_h, b_i, b_h, init_h): | |||
| return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h) | |||