You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_dynamic_stitch_op.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore
  18. import mindspore.context as context
  19. import mindspore.nn as nn
  20. from mindspore import Tensor
  21. from mindspore.ops.operations import _inner_ops as ops
# Module-level, so it runs at import time and applies to every test in this
# file: compile cells in graph mode and execute them on the GPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  23. class Net(nn.Cell):
  24. def __init__(self):
  25. super(Net, self).__init__()
  26. self.stitch = ops.DynamicStitch()
  27. def construct(self, indices, data):
  28. return self.stitch(indices, data)
  29. @pytest.mark.level0
  30. @pytest.mark.platform_x86_gpu_training
  31. @pytest.mark.env_onecard
  32. def test_net_int32():
  33. """
  34. Feature: ALL TO ALL
  35. Description: test cases for dynamicstitch.
  36. Expectation: the result match expected array.
  37. """
  38. x1 = Tensor([6], mindspore.int32)
  39. x2 = Tensor(np.array([4, 1]), mindspore.int32)
  40. x3 = Tensor(np.array([[5, 2], [0, 3]]), mindspore.int32)
  41. y1 = Tensor(np.array([[61, 62]]), mindspore.int32)
  42. y2 = Tensor(np.array([[41, 42], [11, 12]]), mindspore.int32)
  43. y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mindspore.int32)
  44. expected = np.array([[1, 2], [11, 12], [21, 22],
  45. [31, 32], [41, 42], [51, 52], [61, 62]]).astype(np.int32)
  46. indices = [x1, x2, x3]
  47. data = [y1, y2, y3]
  48. net = Net()
  49. output = net(indices, data)
  50. assert np.array_equal(output.asnumpy(), expected)