You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_Tensor.py 1.7 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore._c_dataengine as cde
  17. def test_shape():
  18. x = [2, 3]
  19. s = cde.TensorShape(x)
  20. assert s.as_list() == x
  21. assert s.is_known()
  22. def test_basic():
  23. x = np.array([1, 2, 3, 4, 5])
  24. n = cde.Tensor(x)
  25. arr = np.array(n, copy=False)
  26. arr[0] = 0
  27. x = np.array([0, 2, 3, 4, 5])
  28. assert np.array_equal(x, arr)
  29. assert n.type() == cde.DataType("int64")
  30. arr2 = n.as_array()
  31. arr[0] = 2
  32. x = np.array([2, 2, 3, 4, 5])
  33. assert np.array_equal(x, arr2)
  34. assert n.type() == cde.DataType("int64")
  35. assert arr.__array_interface__['data'] == arr2.__array_interface__['data']
  36. def test_strides():
  37. x = np.array([[1, 2, 3], [4, 5, 6]])
  38. n1 = cde.Tensor(x[:, 1])
  39. arr = np.array(n1, copy=False)
  40. assert np.array_equal(x[:, 1], arr)
  41. n2 = cde.Tensor(x.transpose())
  42. arr = np.array(n2, copy=False)
  43. assert np.array_equal(x.transpose(), arr)
  44. if __name__ == '__main__':
  45. test_shape()
  46. test_strides()
  47. test_basic()