Merge pull request !140 from zhaozhenlong/op/lstm-vm-inc
@@ -299,6 +299,9 @@ class Validator:
        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if isinstance(arg_type, type(mstype.tensor)):
            arg_type = arg_type.element_type()

        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
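The new branch unwraps a tensor wrapper type to its element dtype before the membership test, so a check against element dtypes now also accepts a tensor of that dtype. A hypothetical sketch of the effect (the `mstype.tensor_type` constructor is assumed here for illustration; only the unwrap pattern comes from the diff):

    from mindspore.common import dtype as mstype

    # Hypothetical illustration: tensor_type is assumed, not part of this diff.
    arg_type = mstype.tensor_type(mstype.float16)     # e.g. the dtype of a fp16 Tensor
    valid_types = (mstype.float16, mstype.float32)

    if isinstance(arg_type, type(mstype.tensor)):     # same check as the diff above
        arg_type = arg_type.element_type()            # unwrap Tensor[fp16] -> fp16
    assert arg_type in valid_types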
@@ -709,3 +709,25 @@ def get_bprop_ctc_loss(self):
        return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)

    return bprop


@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop
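For reference, the chain these three kernels implement can be written out directly. Below is a minimal NumPy sketch of the same backward pass, assuming the standard LSTM gradient formulas, post-activation gate caches, `w` flattened to (4*hidden, input+hidden), and an (i, j, f, o) gate order in `dgate`; it illustrates the data flow, not the exact TBE kernel implementations.

    import numpy as np

    def lstm_cell_backward(x, h, c_prev, w, it, jt, ft, ot, tanhct, dht, dct):
        """Illustrative only. Shapes: x (B, I); h, c_prev and gate caches (B, H);
        w (4H, I+H). Gate order in dgate is assumed to be (i, j, f, o)."""
        # Total gradient into c_t: from dct and through h_t = o_t * tanh(c_t).
        dc = dct + dht * ot * (1.0 - tanhct ** 2)
        # Gate gradients: sigmoid gates use g*(1-g); the tanh candidate uses 1-j^2.
        dit = dc * jt * it * (1.0 - it)
        djt = dc * it * (1.0 - jt ** 2)
        dft = dc * c_prev * ft * (1.0 - ft)
        dot = dht * tanhct * ot * (1.0 - ot)
        dgate = np.concatenate([dit, djt, dft, dot], axis=1)   # (B, 4H), as in CStateGrad
        dct_1 = dc * ft                                        # gradient w.r.t. c_{t-1}
        # InputGrad: dgate @ w splits into the x part and the h part.
        dxh = dgate @ w                                        # (B, I+H)
        dxt, dht_prev = dxh[:, :x.shape[1]], dxh[:, x.shape[1]:]
        # WeightGrad: accumulate over the batch against the concatenated input.
        xh = np.concatenate([x, h], axis=1)                    # (B, I+H)
        dw = dgate.T @ xh                                      # (4H, I+H)
        db = dgate.sum(axis=0)                                 # (4H,)
        return dxt, dht_prev, dct_1, dw, db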
@@ -227,3 +227,7 @@ from .asinh_grad import _asinh_grad_tbe
from .atan import _atan_tbe
from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe
from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCell op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("basic_lstm_cell.so") \
    .compute_cost(10) \
    .kernel_name("basic_lstm_cell") \
    .attr("keep_prob", "optional", "float", "all") \
    .attr("forget_bias", "optional", "float", "all") \
    .attr("state_is_tuple", "optional", "bool", "true") \
    .attr("activation", "optional", "str", "all") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "h", False, "required", "all") \
    .input(2, "c", False, "required", "all") \
    .input(3, "w", False, "required", "all") \
    .input(4, "b", False, "required", "all") \
    .input(5, "mask", False, "optional", "all") \
    .output(0, "ct", False, "required", "all") \
    .output(1, "ht", False, "required", "all") \
    .output(2, "it", False, "optional", "all") \
    .output(3, "jt", False, "optional", "all") \
    .output(4, "ft", False, "optional", "all") \
    .output(5, "ot", False, "optional", "all") \
    .output(6, "tanhct", False, "optional", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracZ,
                  DataType.F32_Default, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F32_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
                  DataType.F16_Default, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(basic_lstm_cell_op_info)
def _basic_lstm_cell_tbe():
    """BasicLSTMCell TBE register"""
    return
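Each `dtype_format` row reads positionally: one entry per declared input, then one per output, in registration order (this pairing is inferred from the registration above, not from separately documented behavior). A small sketch of how the first BasicLSTMCell row pairs up:

    io_names = ["x", "h", "c", "w", "b", "mask",                       # inputs 0-5
                "ct", "ht", "it", "jt", "ft", "ot", "tanhct"]          # outputs 0-6
    row = ["F16_FracNZ", "F16_FracNZ", "F32_FracNZ", "F16_FracZ", "F32_Default",
           "U8_Default", "F32_FracNZ", "F16_FracNZ", "F32_FracNZ", "F32_FracNZ",
           "F32_FracNZ", "F32_FracNZ", "F32_FracNZ"]
    for name, fmt in zip(io_names, row):
        print(f"{name}: {fmt}")    # e.g. "c: F32_FracNZ", "w: F16_FracZ"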
@@ -0,0 +1,50 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellCStateGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("basic_lstm_cell_c_state_grad.so") \
    .compute_cost(10) \
    .kernel_name("basic_lstm_cell_c_state_grad") \
    .attr("forget_bias", "optional", "float", "all") \
    .attr("activation", "optional", "str", "all") \
    .partial_flag(True) \
    .input(0, "c", False, "required", "all") \
    .input(1, "dht", False, "required", "all") \
    .input(2, "dct", False, "required", "all") \
    .input(3, "it", False, "required", "all") \
    .input(4, "ft", False, "required", "all") \
    .input(5, "jt", False, "required", "all") \
    .input(6, "ot", False, "required", "all") \
    .input(7, "tanhct", False, "required", "all") \
    .output(0, "dgate", False, "required", "all") \
    .output(1, "dct_1", False, "required", "all") \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F32_FracNZ, DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(basic_lstm_cell_c_state_grad_op_info)
def _basic_lstm_cell_c_state_grad_tbe():
    """BasicLSTMCellCStateGrad TBE register"""
    return
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellInputGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

basic_lstm_cell_input_grad_op_info = TBERegOp("BasicLSTMCellInputGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("basic_lstm_cell_input_grad.so") \
    .compute_cost(10) \
    .kernel_name("basic_lstm_cell_input_grad") \
    .attr("keep_prob", "optional", "float", "all") \
    .partial_flag(True) \
    .input(0, "dgate", False, "required", "all") \
    .input(1, "w", False, "required", "all") \
    .input(2, "dropout_mask", False, "optional", "all") \
    .output(0, "dxt", False, "required", "all") \
    .output(1, "dht", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F32_FracNZ,
                  DataType.F32_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.U8_Default, DataType.F16_FracNZ,
                  DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(basic_lstm_cell_input_grad_op_info)
def _basic_lstm_cell_input_grad_tbe():
    """BasicLSTMCellInputGrad TBE register"""
    return
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BasicLSTMCellWeightGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

basic_lstm_cell_weight_grad_op_info = TBERegOp("BasicLSTMCellWeightGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("basic_lstm_cell_weight_grad.so") \
    .compute_cost(10) \
    .kernel_name("basic_lstm_cell_weight_grad") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "h", False, "required", "all") \
    .input(2, "dgate", False, "required", "all") \
    .output(0, "dw", False, "required", "all") \
    .output(1, "db", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
                  DataType.F32_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracZ,
                  DataType.F16_Default) \
    .get_op_info()


@op_info_register(basic_lstm_cell_weight_grad_op_info)
def _basic_lstm_cell_weight_grad_tbe():
    """BasicLSTMCellWeightGrad TBE register"""
    return
@@ -71,7 +71,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     ApplyProximalAdagrad, SparseApplyProximalAdagrad,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
from . import _quant_ops
from ._quant_ops import *
@@ -285,7 +285,8 @@ __all__ = [
    "BesselI0e",
    "BesselI1e",
    "Atan",
    "Atanh",
    "BasicLSTMCell"
]
__all__.extend(_quant_ops.__all__)
@@ -1173,3 +1173,106 @@ class AtanGrad(PrimitiveWithInfer):
        args = {"x": x, "dout": dout}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x
class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
    """Computes the cell-state gradients of `BasicLSTMCell`."""

    @prim_attr_register
    def __init__(self, forget_bias, activation):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)

    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
        # dht, dct and all gate caches must have the same shape as c
        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
        dgate_shape = (c_shape[0], 4 * c_shape[1])
        dct_1_shape = c_shape
        return (dgate_shape, dct_1_shape)

    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
        validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
        validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
        validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
        validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
        return (c_dtype, c_dtype)
class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
    """Computes the weight gradients of `BasicLSTMCell`."""

    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, x_shape, h_shape, dgate_shape):
        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
        validator.check("h rank", len(h_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
        db_shape = (dgate_shape[1], 1, 1, 1)
        return (dw_shape, db_shape)

    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
        validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        return (x_dtype, x_dtype)
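A quick worked example of the shape rules above, using the sizes from the test case at the end of this PR (batch 128, input 128, hidden 128); the helper below simply re-applies `infer_shape`'s arithmetic:

    def weight_grad_shapes(x_shape, h_shape, dgate_shape):
        # Mirrors BasicLSTMCellWeightGrad.infer_shape:
        # dw is (4H, I+H, 1, 1) and db is (4H, 1, 1, 1).
        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
        db_shape = (dgate_shape[1], 1, 1, 1)
        return dw_shape, db_shape

    print(weight_grad_shapes((128, 128), (128, 128), (128, 512)))
    # ((512, 256, 1, 1), (512, 1, 1, 1)) -- matching the shapes of w and b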
class BasicLSTMCellInputGrad(PrimitiveWithInfer):
    """Computes the input gradients of `BasicLSTMCell`."""

    @prim_attr_register
    def __init__(self, keep_prob):
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)

    def infer_shape(self, dgate_shape, w_shape):
        validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
        return (dxt_shape, dht_shape)

    def infer_dtype(self, dgate_dtype, w_dtype):
        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
        validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
        return (dgate_dtype, dgate_dtype)
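The same sizes make the input-gradient arithmetic concrete: with `dgate` of shape (128, 512) and `w` of shape (512, 256, 1, 1), `w_shape[1] - w_shape[0] // 4` recovers the input size and `dgate_shape[1] // 4` the hidden size.

    def input_grad_shapes(dgate_shape, w_shape):
        # Mirrors BasicLSTMCellInputGrad.infer_shape.
        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)  # (B, I)
        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)           # (B, H)
        return dxt_shape, dht_shape

    print(input_grad_shapes((128, 512), (512, 256, 1, 1)))
    # ((128, 128), (128, 128))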
@@ -3418,3 +3418,109 @@ class CTCLoss(PrimitiveWithInfer):
        validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
        return inputs, inputs
class BasicLSTMCell(PrimitiveWithInfer):
    r"""
    Performs a single long short-term memory (LSTM) cell computation on the input.

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ix} x_t + b_{ix} + W_{ih} h_{(t-1)} + b_{ih}) \\
            f_t = \sigma(W_{fx} x_t + b_{fx} + W_{fh} h_{(t-1)} + b_{fh}) \\
            \tilde{c}_t = \tanh(W_{cx} x_t + b_{cx} + W_{ch} h_{(t-1)} + b_{ch}) \\
            o_t = \sigma(W_{ox} x_t + b_{ox} + W_{oh} h_{(t-1)} + b_{oh}) \\
            c_t = f_t * c_{(t-1)} + i_t * \tilde{c}_t \\
            h_t = o_t * \tanh(c_t) \\
        \end{array}

    Here :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
    are learnable weights between the output and the input in the formula. For instance,
    :math:`W_{ix}, b_{ix}` are the weight and bias used to transform the input :math:`x` into :math:`i`.
    Details can be found in the papers `LONG SHORT-TERM MEMORY
    <https://www.bioinf.jku.at/publications/older/2604.pdf>`_ and
    `Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
    <https://static.googleusercontent.com/media/research.google.com/zh-CN//pubs/archive/43905.pdf>`_.

    Args:
        keep_prob (float): Keep probability for dropout; if not 1.0, dropout is applied to the input.
            The value must be in the range [0.0, 1.0]. Default: 1.0.
        forget_bias (float): Bias added to the forget gate, which reduces the scale of forgetting
            at the beginning of training. Default: 1.0.
        state_is_tuple (bool): If True, the state is a tuple of two tensors, `h` and `c`;
            if False, the state is a single tensor that needs to be split first. Default: True.
        activation (str): Activation function. Only 'tanh' is currently supported. Default: 'tanh'.

    Inputs:
        - **x** (Tensor) - Current input. Tensor of shape (`batch_size`, `input_size`).
        - **h** (Tensor) - Hidden state from the previous time step. Tensor of shape (`batch_size`, `hidden_size`).
        - **c** (Tensor) - Cell state from the previous time step. Tensor of shape (`batch_size`, `hidden_size`).
        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).

    Outputs:
        - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
        - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
        - **jt** (Tensor) - Forward :math:`\tilde{c}_t` cache at moment `t`. Tensor of shape
          (`batch_size`, `hidden_size`).
        - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
        - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
        - **tanhct** (Tensor) - Forward :math:`\tanh(c_t)` cache at moment `t`.
          Tensor of shape (`batch_size`, `hidden_size`).

    Examples:
        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
        >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
        >>> lstm(x, h, c, w, b)
    """
    @prim_attr_register
    def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation="tanh"):
        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
        self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
        self.activation = validator.check_string("activation", activation, ['tanh'], self.name)

    def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
        # x: (batch_size, input_size)
        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
        # h and c should be same shape
        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
        ct_shape = c_shape
        ht_shape = h_shape
        it_shape = h_shape
        jt_shape = h_shape
        ft_shape = h_shape
        ot_shape = h_shape
        tanhct_shape = h_shape
        return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)

    def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
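To make the docstring equations concrete, here is a minimal NumPy sketch of a single forward step, assuming the gates are packed in (i, j, f, o) order along the first axis of `w` and `b`, and that `w` and `b` are flattened from their (4H, I+H, 1, 1) and (4H, 1, 1, 1) layouts (the packing order and flattening are assumptions for illustration; only the equations come from the docstring):

    import numpy as np

    def basic_lstm_cell_forward(x, h, c, w, b, forget_bias=1.0):
        """Illustrative forward step; w is (4H, I+H), b is (4H,)."""
        def sigmoid(z):
            return 1.0 / (1.0 + np.exp(-z))

        gates = np.concatenate([x, h], axis=1) @ w.T + b    # (B, 4H) pre-activations
        i, j, f, o = np.split(gates, 4, axis=1)             # assumed (i, j, f, o) order
        it, jt = sigmoid(i), np.tanh(j)                     # input gate and candidate
        ft = sigmoid(f + forget_bias)                       # forget_bias added pre-activation
        ot = sigmoid(o)                                     # output gate
        ct = ft * c + it * jt
        tanhct = np.tanh(ct)
        ht = ot * tanhct
        return ct, ht, it, jt, ft, ot, tanhct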
@@ -878,6 +878,11 @@ test_case_nn_ops = [
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'skip': ['backward']}),
    ('BasicLSTMCell', {
        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
        'skip': []}),
    ('TopK', {
        'block': P.TopK(),
        'desc_const': [5],