diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 63b4851995..7a22394d97 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5764,7 +5764,12 @@ class ApplyFtrl(PrimitiveWithInfer): Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type. Outputs: - Tensor, represents the updated `var`. + There are three outputs for Ascend environment. + - **var** (Tensor) - represents the updated `var`. + - **accum** (Tensor) - represents the updated `accum`. + - **linear** (Tensor) - represents the updated `linear`. + There is only one output for GPU environment. + - **var** (Tensor) - This value is alwalys zero and the input parameters has been updated in-place. Supported Platforms: ``Ascend`` ``GPU`` @@ -5773,8 +5778,8 @@ class ApplyFtrl(PrimitiveWithInfer): >>> import mindspore >>> import mindspore.nn as nn >>> import numpy as np - >>> from mindspore import Parameter - >>> from mindspore import Tensor + >>> from mindspore import Parameter, Tensor + >>> import mindspore.context as context >>> from mindspore.ops import operations as ops >>> class ApplyFtrlNet(nn.Cell): ... def __init__(self): @@ -5797,7 +5802,9 @@ class ApplyFtrl(PrimitiveWithInfer): >>> net = ApplyFtrlNet() >>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32) >>> output = net(input_x) - >>> print(output) + >>> is_tbe = context.get_context("device_target") == "Ascend" + >>> if is_tbe: + ... print(output) (Tensor(shape=[2, 2], dtype=Float32, value= [[ 4.61418092e-01, 5.30964255e-01], [ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value= @@ -5805,6 +5812,16 @@ class ApplyFtrl(PrimitiveWithInfer): [ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= [[-1.86994812e+03, -1.64906018e+03], [-3.22187836e+02, -1.20163989e+03]])) + >>> else: + ... print(net.var.asnumpy()) + [[0.4614181 0.5309642 ] + [0.2687151 0.38206503]] + ... print(net.accum.asnumpy()) + [[16.423655 9.645894 ] + [ 1.4375873 9.891773 ]] + ... print(net.linear.asnumpy()) + [[-1869.9479 -1649.0599] + [ -322.1879 -1201.6399]] """ @prim_attr_register