|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2019 Huawei Technologies Co., Ltd |
|
|
|
# Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@@ -44,16 +44,13 @@ class NetOneHot(nn.Cell): |
|
|
|
self.one_hot_3(indices3), self.one_hot_4(indices4)) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_x86_gpu_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_one_hot(): |
|
|
|
one_hot = NetOneHot() |
|
|
|
indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32)) |
|
|
|
indices2 = Tensor(np.array([1, 2, 3]).astype(np.int32)) |
|
|
|
indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32)) |
|
|
|
indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32)) |
|
|
|
output = one_hot(indices1, indices2, indices3, indices4) |
|
|
|
def one_hot(nptype): |
|
|
|
one_hot_net = NetOneHot() |
|
|
|
indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype)) |
|
|
|
indices2 = Tensor(np.array([1, 2, 3]).astype(nptype)) |
|
|
|
indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(nptype)) |
|
|
|
indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(nptype)) |
|
|
|
output = one_hot_net(indices1, indices2, indices3, indices4) |
|
|
|
expect_0 = np.array([ |
|
|
|
[[2., 3., 3., 3., 3., 3.], [3., 2., 3., 3., 3., 3.]], |
|
|
|
[[3., 3., 3., 3., 2., 3.], [3., 3., 3., 3., 3., 2.]], |
|
|
|
@@ -80,3 +77,15 @@ def test_one_hot(): |
|
|
|
assert (output[1].asnumpy() == expect_1).all() |
|
|
|
assert (output[2].asnumpy() == expect_2).all() |
|
|
|
assert (output[3].asnumpy() == expect_3).all() |
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_x86_gpu_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_one_hot_int32(): |
|
|
|
one_hot(np.int32) |
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_x86_gpu_training |
|
|
|
@pytest.mark.env_onecard |
|
|
|
def test_one_hot_int64(): |
|
|
|
one_hot(np.int64) |