From c34258d08d1b0c89f9ae7f781eb0169f860da896 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Mon, 18 Jan 2021 00:46:51 -0500 Subject: [PATCH] initial commit --- mindspore/ops/operations/nn_ops.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 5f524aa79b..0191a7068a 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5340,25 +5340,26 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): ... def __init__(self): ... super(Net, self).__init__() ... self.sparse_apply_proximal_adagrad = ops.SparseApplyProximalAdagrad() - ... self.var = Parameter(Tensor(np.random.rand(1, 2).astype(np.float32)), name="var") - ... self.accum = Parameter(Tensor(np.random.rand(1, 2).astype(np.float32)), name="accum") - ... self.lr = 0.01 - ... self.l1 = 0.0 + ... self.var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]], np.float32)), name="var") + ... self.accum = Parameter(Tensor(np.array([[0, 0], [0, 0]], np.float32)), name="accum") + ... self.lr = 1.0 + ... self.l1 = 1.0 ... self.l2 = 0.0 ... def construct(self, grad, indices): ... out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, ... self.l2, grad, indices) ... return out ... - >>> np.random.seed(0) >>> net = Net() - >>> grad = Tensor(np.random.rand(1, 2).astype(np.float32)) - >>> indices = Tensor(np.ones((1,), np.int32)) + >>> grad = Tensor(np.array([[1, 1], [1, 1]], np.float32)) + >>> indices = Tensor(np.array([0], np.int32)) >>> output = net(grad, indices) >>> print(output) - (Tensor(shape=[1, 2], dtype=Float32, value= - [[ 5.48813522e-01, 7.15189338e-01]]), Tensor(shape=[1, 2], dtype=Float32, value= - [[ 6.02763355e-01, 5.44883192e-01]])) + (Tensor(shape=[2, 2], dtype=Float32, value= + [[ 2.97499990e+00, 6.07499981e+00], + [ 0.00000000e+00, 1.87500000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= + [[ 6.40000000e+01, 6.40000000e+01], + [ 6.40000000e+01, 6.40000000e+01]])) """ __mindspore_signature__ = (