You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_while_ScatterNdUpdate.py 943 B

12345678910111213141516171819202122232425262728293031
  1. import numpy as np
  2. from mindspore import context, nn, Tensor, Parameter
  3. from mindspore.common import dtype as mstype
  4. from mindspore.ops import operations as P
  5. context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
  6. class Net(nn.Cell):
  7. def __init__(self, data):
  8. super(Net, self).__init__()
  9. self.start = Tensor(0, dtype=mstype.int32)
  10. self.end = Tensor(2, dtype=mstype.int32)
  11. self.max_output = Parameter(data, "output_x")
  12. self.upd = P.ScatterNdUpdate()
  13. self.zero = Tensor(np.ones([1], dtype=np.int32))
  14. def construct(self, inputs):
  15. idx = self.start
  16. end = self.end
  17. while idx < end:
  18. xi = inputs[idx, :, :]
  19. self.upd(self.max_output, idx + self.zero, xi)
  20. idx = idx + 1
  21. return self.max_output + 0
  22. def test_x():
  23. x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
  24. net = Net(x)
  25. net(x)