|
|
|
@@ -480,6 +480,7 @@ def test_tensor_operation(): |
|
|
|
def test_tensor_from_numpy(): |
|
|
|
a = np.ones((2, 3)) |
|
|
|
t = ms.Tensor.from_numpy(a) |
|
|
|
assert isinstance(t, ms.Tensor) |
|
|
|
assert np.all(t.asnumpy() == 1) |
|
|
|
# 't' and 'a' share same data. |
|
|
|
a[1] = 2 |
|
|
|
@@ -489,3 +490,6 @@ def test_tensor_from_numpy(): |
|
|
|
del a |
|
|
|
assert np.all(t.asnumpy()[0] == 1) |
|
|
|
assert np.all(t.asnumpy()[1] == 2) |
|
|
|
with pytest.raises(TypeError): |
|
|
|
# incorrect input. |
|
|
|
t = ms.Tensor.from_numpy([1, 2, 3]) |