|
|
|
@@ -5764,7 +5764,12 @@ class ApplyFtrl(PrimitiveWithInfer): |
|
|
|
Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, represents the updated `var`. |
|
|
|
There are three outputs for Ascend environment. |
|
|
|
- **var** (Tensor) - represents the updated `var`. |
|
|
|
- **accum** (Tensor) - represents the updated `accum`. |
|
|
|
- **linear** (Tensor) - represents the updated `linear`. |
|
|
|
There is only one output for GPU environment. |
|
|
|
- **var** (Tensor) - This value is alwalys zero and the input parameters has been updated in-place. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
@@ -5773,8 +5778,8 @@ class ApplyFtrl(PrimitiveWithInfer): |
|
|
|
>>> import mindspore |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> import numpy as np |
|
|
|
>>> from mindspore import Parameter |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> from mindspore import Parameter, Tensor |
|
|
|
>>> import mindspore.context as context |
|
|
|
>>> from mindspore.ops import operations as ops |
|
|
|
>>> class ApplyFtrlNet(nn.Cell): |
|
|
|
... def __init__(self): |
|
|
|
@@ -5797,7 +5802,9 @@ class ApplyFtrl(PrimitiveWithInfer): |
|
|
|
>>> net = ApplyFtrlNet() |
|
|
|
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 2)), mindspore.float32) |
|
|
|
>>> output = net(input_x) |
|
|
|
>>> print(output) |
|
|
|
>>> is_tbe = context.get_context("device_target") == "Ascend" |
|
|
|
>>> if is_tbe: |
|
|
|
... print(output) |
|
|
|
(Tensor(shape=[2, 2], dtype=Float32, value= |
|
|
|
[[ 4.61418092e-01, 5.30964255e-01], |
|
|
|
[ 2.68715084e-01, 3.82065028e-01]]), Tensor(shape=[2, 2], dtype=Float32, value= |
|
|
|
@@ -5805,6 +5812,16 @@ class ApplyFtrl(PrimitiveWithInfer): |
|
|
|
[ 1.43758726e+00, 9.89177322e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= |
|
|
|
[[-1.86994812e+03, -1.64906018e+03], |
|
|
|
[-3.22187836e+02, -1.20163989e+03]])) |
|
|
|
>>> else: |
|
|
|
... print(net.var.asnumpy()) |
|
|
|
[[0.4614181 0.5309642 ] |
|
|
|
[0.2687151 0.38206503]] |
|
|
|
... print(net.accum.asnumpy()) |
|
|
|
[[16.423655 9.645894 ] |
|
|
|
[ 1.4375873 9.891773 ]] |
|
|
|
... print(net.linear.asnumpy()) |
|
|
|
[[-1869.9479 -1649.0599] |
|
|
|
[ -322.1879 -1201.6399]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|