@@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"softmax", "softmax_v2"},
   {"log_softmax", "log_softmax_v2"},
   {"apply_momentum", "apply_momentum_d"},
+  {"apply_ftrl", "apply_ftrl_d"},
   {"re_lu6", "relu6"},
   {"re_lu6_grad", "relu6_grad"},
   {"re_lu", "relu"},
@@ -384,7 +384,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
   {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
   {string(kNameSign), ADPT_DESC(Sign)},
   {string(kNameRound), ADPT_DESC(Round)},
-  {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
+  {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
   {string(kNameDiag), ADPT_DESC(Diag)},
   {string(kNameDiagPart), ADPT_DESC(DiagPart)},
   {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
@@ -1176,11 +1176,11 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};

 // ApplyFtrl
-INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
-                        {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
-                        {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
+INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
+                         {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
+                         {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
+ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};

 // Diag
 INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
 DECLARE_OP_USE_OUTPUT(LarsV2Update)
 DECLARE_OP_ADAPTER(Round)
 DECLARE_OP_USE_OUTPUT(Round)
-DECLARE_OP_ADAPTER(ApplyFtrl)
-DECLARE_OP_USE_OUTPUT(ApplyFtrl)
+DECLARE_OP_ADAPTER(ApplyFtrlD)
+DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
 DECLARE_OP_ADAPTER(Diag)
@@ -326,6 +326,18 @@ def get_bprop_log_softmax(self):
     return bprop


+@bprop_getters.register(P.Softplus)
+def get_bprop_softplus(self):
+    """Grad definition for `Softplus` operation."""
+    softplus_grad = G.SoftplusGrad()
+
+    def bprop(x, out, dout):
+        dx = softplus_grad(dout, x)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.Tanh)
 def get_bprop_tanh(self):
     """Grad definition for `Tanh` operation."""
@@ -100,6 +100,8 @@ from .round import _round_tbe
 from .tanh import _tanh_tbe
 from .tanh_grad import _tanh_grad_tbe
 from .softmax import _softmax_tbe
+from .softplus import _softplus_tbe
+from .softplus_grad import _softplus_grad_tbe
 from .square import _square_tbe
 from .sqrt import _sqrt_tbe
 from .transpose_d import _transpose_d_tbe
@@ -32,30 +32,32 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
     .input(6, "l2", False, "required", "all") \
     .input(7, "lr_power", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
+    .output(2, "linear", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                   DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_5HD) \
+                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
                   DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_FracZ) \
+                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
     .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                   DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_C1HWNCoC0) \
+                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default) \
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                   DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                   DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_C1HWNCoC0) \
+                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .get_op_info()
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Softplus op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+softplus_op_info = TBERegOp("Softplus") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("softplus.so") \
+    .compute_cost(10) \
+    .kernel_name("softplus") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(softplus_op_info)
+def _softplus_tbe():
+    """Softplus TBE register"""
+    return
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SoftplusGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+softplus_grad_op_info = TBERegOp("SoftplusGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("softplus_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("softplus_grad") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "gradients", False, "required", "all") \
+    .input(1, "features", False, "required", "all") \
+    .output(0, "backprops", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(softplus_grad_op_info)
+def _softplus_grad_tbe():
+    """SoftplusGrad TBE register"""
+    return
@@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid,
                      SigmoidCrossEntropyWithLogits,
-                     SmoothL1Loss, Softmax,
+                     SmoothL1Loss, Softmax, Softplus,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
@@ -974,6 +974,23 @@ class StridedSliceGrad(PrimitiveWithInfer):
                 'value': None}


+class SoftplusGrad(PrimitiveWithInfer):
+    """Computes gradient for the Softplus activation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])
+
+    def infer_shape(self, dout_shape, x_shape):
+        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, dout_dtype, x_dtype):
+        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
+        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        return x_dtype
+
+
 class TanhGrad(PrimitiveWithInfer):
     """Computes gradient of hyperbolic tangent of input element-wise."""
@@ -183,6 +183,41 @@ class LogSoftmax(PrimitiveWithInfer):
         return logits


+class Softplus(PrimitiveWithInfer):
+    r"""
+    Softplus activation function.
+
+    Softplus is a smooth approximation to the ReLU function.
+    The function is shown as follows:
+
+    .. math::
+
+        \text{output} = \log(1 + \exp(\text{input_x})),
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor whose data type should be float.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
+        >>> softplus = P.Softplus()
+        >>> softplus(input_x)
+        [1.3132615, 2.126928, 3.0485873, 4.01815, 5.0067153]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init Softplus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, input_x):
+        return input_x
+
+    def infer_dtype(self, input_x):
+        validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
+        return input_x
+
+
 class ReLU(PrimitiveWithInfer):
     r"""
     Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
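The docstring example can be reproduced with plain numpy; log1p keeps the computation stable and reproduces the listed values up to float32 rounding:

import numpy as np

x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
print(np.log1p(np.exp(x)))  # matches the docstring output up to float32 rounding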
@@ -2701,11 +2736,14 @@ class ApplyFtrl(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                 outputs=['output'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.is_tbe = context.get_context("device_target") == "Ascend"

     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if self.is_tbe:
+            return var_shape, var_shape, var_shape
         return var_shape

     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -2717,6 +2755,8 @@ class ApplyFtrl(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
+        if self.is_tbe:
+            return var_type, var_type, var_type
         return var_type
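Note that the inferred signature now depends on the backend: on Ascend the op reports three outputs (var, accum, linear) to line up with apply_ftrl_d, and a single var elsewhere. A small hypothetical helper (not part of this patch) shows how calling code could stay backend-agnostic:

def updated_var(out):
    """Return the updated var whether ApplyFtrl produced 1 or 3 outputs."""
    if isinstance(out, tuple):  # Ascend/TBE path: (var, accum, linear)
        return out[0]
    return out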
@@ -185,6 +185,22 @@ class ScatterMax(nn.Cell):
         out = self.scatter_max(self.ref, indices, updates)
         return out


+class ApplyFtrlNet(nn.Cell):
+    def __init__(self):
+        super(ApplyFtrlNet, self).__init__()
+        self.apply_ftrl = P.ApplyFtrl()
+        self.lr = 0.001
+        self.l1 = 0.0
+        self.l2 = 0.0
+        self.lr_power = -0.5
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
+
+    def construct(self, grad):
+        out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
+        return out
+
+
 test_case_math_ops = [
     ('Neg', {
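The net above is driven by the 'ApplyFtrl' test entry further down; run on its own it would look roughly like this (a sketch, assuming a configured MindSpore context):

import numpy as np
from mindspore import Tensor

net = ApplyFtrlNet()
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
out = net(grad)  # updated var (a 3-tuple with accum and linear on the TBE path)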
@@ -602,6 +618,14 @@ test_case_nn_ops = [
         'block': G.ReluGrad(),
         'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
         'skip': ['backward']}),
+    ('Softplus', {
+        'block': P.Softplus(),
+        'desc_inputs': [[1, 3, 4, 4]],
+        'desc_bprop': [[1, 3, 4, 4]]}),
+    ('SoftplusGrad', {
+        'block': G.SoftplusGrad(),
+        'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
+        'skip': ['backward']}),
     ('Elu', {
         'block': P.Elu(),
         'desc_inputs': [[2, 3, 4]],
@@ -869,9 +893,8 @@ test_case_nn_ops = [
         'desc_inputs': [[3, 2]],
         'desc_bprop': [[3, 2]]}),
     ('ApplyFtrl', {
-        'block': P.ApplyFtrl(),
-        'desc_const': [0.001, 0.0, 0.0, -0.5],
-        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
+        'block': ApplyFtrlNet(),
+        'desc_inputs': [[3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyRMSProp', {