You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_array.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore
  18. import mindspore.context as context
  19. import mindspore.nn as nn
  20. from mindspore import Tensor
  class TensorArrayNet(nn.Cell):
      """Test Cell that fills an ``nn.TensorArray`` twice and reads it back.

      ``construct`` writes two rounds of 10 values. It clears the array after
      the first round, so only the second round remains. It then returns the
      last element written together with the stacked contents, and closes the
      array. ``test_tensorarray`` runs this Cell under ``context.GRAPH_MODE``.
      """

      def __init__(self, dtype, element_shape):
          """Create the wrapped ``nn.TensorArray``.

          Args:
              dtype: element dtype forwarded to ``nn.TensorArray``
                  (``test_tensorarray`` passes ``mindspore.int64``).
              element_shape: per-element shape forwarded to ``nn.TensorArray``
                  (``test_tensorarray`` passes ``()``).
          """
          super(TensorArrayNet, self).__init__()
          self.ta = nn.TensorArray(dtype, element_shape)

      def construct(self, index, value):
          # test_tensorarray runs this method under context.GRAPH_MODE, so any
          # edit here must stay within the Python subset MindSpore graph mode
          # accepts (hence comments only, no docstring).
          # NOTE(review): index/value are assumed to be scalar integer Tensors,
          # as passed by test_tensorarray -- confirm for other callers.
          # Returns (v, s):
          #   v is the element at position index-1 after the loops, i.e. the
          #     last slot written.
          #   s is the stack of every element still in the array.
          for i in range(2):
              # Write 10 consecutive values into 10 consecutive positions.
              for _ in range(10):
                  self.ta.write(index, value)
                  index += 1
                  value += 1
              if i == 0:
                  # Discard the first round and restart at position 0, so only
                  # the second round's values remain in the array.
                  self.ta.clear()
                  index = 0
          # After the second round index has advanced from 0 to 10, so index-1
          # addresses the last written slot.
          v = self.ta.read(index-1)
          s = self.ta.stack()
          self.ta.close()
          return v, s
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_tensorarray():
      """
      Feature: TensorArray on GPU.
      Description: Run write/read/stack/clear/close through a graph-compiled Cell, then
          drive write/read/stack/size/clear/close eagerly in PyNative mode, including
          writes beyond the current end of the array.
      Expectation: every read, stack and size result equals the hand-computed value.
      """
      def _matches(tensor, expected):
          # Element-wise tolerance check of a MindSpore tensor against plain values.
          return np.allclose(tensor.asnumpy(), expected)

      # Graph mode: two rounds of 10 writes with a clear() in between, so only
      # the second round (values 15..24) survives in the array.
      context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
      start_index = Tensor(0, mindspore.int64)
      start_value = Tensor(5, mindspore.int64)
      net = TensorArrayNet(dtype=mindspore.int64, element_shape=())
      last, stacked = net(start_index, start_value)
      assert _matches(stacked, list(range(15, 25)))
      assert _matches(last, 24)

      # PyNative mode, int64: fill five slots, inspect, clear, then refill one slot.
      context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
      int_array = nn.TensorArray(mindspore.int64, ())
      for slot in range(5):
          int_array.write(slot, 99)
      first = int_array.read(0)
      stacked = int_array.stack()
      assert _matches(stacked, [99] * 5)
      assert _matches(first, 99)
      assert _matches(int_array.size(), 5)
      int_array.clear()
      assert _matches(int_array.size(), 0)
      int_array.write(0, 88)
      first = int_array.read(0)
      stacked = int_array.stack()
      int_array.close()
      assert _matches(stacked, [88])
      assert _matches(first, 88)

      # PyNative mode, float32: writing past the current end zero-fills the gap.
      float_array = nn.TensorArray(mindspore.float32, ())
      float_array.write(5, 1.)
      assert _matches(float_array.stack(), [0.] * 5 + [1.])
      float_array.write(2, 1.)
      assert _matches(float_array.stack(), [0., 0., 1., 0., 0., 1.])
      float_array.close()