|
|
|
@@ -51,16 +51,18 @@ class ControlDepend(Primitive): |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.global_step = mindspore.Parameter(initializer(0, [1]), name="global_step") |
|
|
|
>>> self.rate = 0.2 |
|
|
|
>>> self.control_depend = P.ControlDepend() |
|
|
|
>>> self.softmax = P.Softmax() |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> data = self.rate * self.global_step + x |
|
|
|
>>> added_global_step = self.global_step + 1 |
|
|
|
>>> self.global_step = added_global_step |
|
|
|
>>> self.control_depend(data, added_global_step) |
|
|
|
>>> return data |
|
|
|
>>> def construct(self, x, y): |
|
|
|
>>> mul = x * y |
|
|
|
>>> softmax = self.softmax(x) |
|
|
|
>>> ret = self.control_depend(mul, softmax) |
|
|
|
>>> return ret |
|
|
|
>>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32) |
|
|
|
>>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32) |
|
|
|
>>> net = Net() |
|
|
|
>>> output = net(x, y) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|