You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_util.py 3.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore as ms
  16. from mindspore import Tensor
  17. from mindspore.parallel._utils import _to_full_shapes, _to_full_tensor
  18. def test_to_full_shapes():
  19. device_num = 16
  20. shapes = [[32, 128], [12], [24, 1, 12]]
  21. full_shapes = _to_full_shapes(shapes, device_num)
  22. assert full_shapes == [(512, 128), (192,), (384, 1, 12)]
  23. def test_to_full_tensor_1():
  24. elem = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  25. device_num = 4
  26. global_rank = 2
  27. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  28. expect = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  29. expect_tensor = Tensor(expect, dtype=ms.float32)
  30. assert np.all(full_tensor[0].asnumpy() == expect_tensor.asnumpy())
  31. def test_to_full_tensor_2():
  32. elem0 = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  33. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  34. elem = (elem0, elem1,)
  35. device_num = 4
  36. global_rank = 2
  37. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=None)
  38. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  39. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  40. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  41. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  42. expect_tensors = (expect_tensor0, expect_tensor1)
  43. assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
  44. assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
  45. def test_to_full_tensor_sens_2():
  46. elem0 = Tensor([[1, 2, 3], [4, 5, 6]], dtype=ms.float32)
  47. elem1 = Tensor([[1], [4]], dtype=ms.int32)
  48. elem = (elem0, elem1,)
  49. device_num = 4
  50. global_rank = 2
  51. full_tensor = _to_full_tensor(elem, device_num, global_rank, scaling_sens=0.1)
  52. expect0 = ([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3], [4, 5, 6], [0, 0, 0], [0, 0, 0]])
  53. expect_tensor0 = Tensor(expect0, dtype=ms.float32)
  54. expect1 = ([[0], [0], [0], [0], [1], [4], [0], [0]])
  55. expect_tensor1 = Tensor(expect1, dtype=ms.int32)
  56. expect_tensor_sens = Tensor(0.1, dtype=ms.float32)
  57. expect_tensors = (expect_tensor0, expect_tensor1, expect_tensor_sens)
  58. assert np.all(full_tensor[0].asnumpy() == expect_tensors[0].asnumpy())
  59. assert np.all(full_tensor[1].asnumpy() == expect_tensors[1].asnumpy())
  60. assert np.all(full_tensor[2].asnumpy() == expect_tensors[2].asnumpy())