
!12356 Add float64 support to gpu Add, Sub, Mul and Div

From: @peilin-wang
Reviewed-by: @robingrosman
Signed-off-by: @robingrosman
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
commit d6b87269a2
8 changed files with 325 additions and 144 deletions
1. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu (+11, -1)
2. mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc (+15, -1)
3. mindspore/ccsrc/backend/kernel_compiler/gpu/other/gpu_convert_to_dynamic_shape_gpu_kernel.cc (+5, -1)
4. tests/st/ops/gpu/test_add_op.py (+87, -33)
5. tests/st/ops/gpu/test_div_op.py (+47, -42)
6. tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py (+13, -1)
7. tests/st/ops/gpu/test_mul_op.py (+93, -38)
8. tests/st/ops/gpu/test_sub_op.py (+54, -27)
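
Taken together, the changes instantiate the `double` CUDA templates and register fp64 kernels, so float64 tensors now dispatch to native GPU kernels for the four arithmetic ops. A minimal usage sketch of what this enables, assuming a GPU build of MindSpore at this revision (imports and values are illustrative, not taken from the diff):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target='GPU')

x = Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float64))
y = Tensor(np.array([2.0]).astype(np.float64))  # broadcasts against x

# each of these now resolves to a registered float64 GPU kernel
for op in (P.Add(), P.Sub(), P.Mul(), P.Div()):
    print(op(x, y).asnumpy())
```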

+11 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -222,6 +222,8 @@ void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *
   }
 }
 
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const double *x0, const double *x1, bool *y,
+                         cudaStream_t stream);
 template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, bool *y,
                          cudaStream_t stream);
 template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, bool *y,
@@ -292,6 +294,8 @@ void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, cons
   }
 }
 
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const double *x0, const double *x1, double *y,
+                           cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
                            cudaStream_t stream);
 template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
@@ -372,6 +376,9 @@ void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t>
   }
 }
 
+template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                           const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
+                           const double *x1, bool *y, cudaStream_t stream);
 template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                            const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
                            bool *y, cudaStream_t stream);
@@ -501,6 +508,9 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
   }
 }
 
+template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
+                             const std::vector<size_t> &y_dims, enum BroadcastOpType op, const double *x0,
+                             const double *x1, double *y, cudaStream_t stream);
 template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
                              const std::vector<size_t> &y_dims, enum BroadcastOpType op, const float *x0,
                              const float *x1, float *y, cudaStream_t stream);
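
These kernel templates are defined in this .cu translation unit, so every element type used by a registered kernel must be explicitly instantiated here; without the new `double` instantiations, the fp64 registrations in broadcast_gpu_kernel.cc would fail to link. The Broadcast* variants implement NumPy-style broadcasting; a pure-NumPy sketch of the shape pairs the new tests exercise (illustrative only):

```python
import numpy as np

# broadcast pairs exercised by the GPU tests in this commit, checked in NumPy
pairs = [((2, 3, 4, 4), (2, 3, 4, 4)),  # elementwise, no broadcast
         ((2, 3, 4, 4), (2, 1, 4, 4)),  # broadcast over axis 1
         ((2, 1, 1, 4), (2, 3, 4, 4))]  # broadcast over axes 1 and 2
for shape_x, shape_y in pairs:
    x = np.ones(shape_x, dtype=np.float64)
    y = np.full(shape_y, 2.0, dtype=np.float64)
    assert (x * y).shape == np.broadcast(x, y).shape
```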


+15 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,20 @@
 
 namespace mindspore {
 namespace kernel {
+// fp64
+MS_REG_GPU_KERNEL_ONE(
+  Add, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(
+  Sub, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(
+  Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(
+  Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  BroadcastOpGpuKernel, double)
+
 // fp32
 MS_REG_GPU_KERNEL_ONE(
   Greater,
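
Each MS_REG_GPU_KERNEL_ONE call binds an op name plus an input/output dtype signature to BroadcastOpGpuKernel instantiated with one C++ element type, and the framework selects the kernel from the operand dtypes at run time. A hedged sketch of what that means from Python (GPU build assumed):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target='GPU')
mul = P.Mul()

# one primitive, two registrations: the operand dtype picks the kernel
out32 = mul(Tensor(np.array([1.5], np.float32)), Tensor(np.array([2.0], np.float32)))
out64 = mul(Tensor(np.array([1.5], np.float64)), Tensor(np.array([2.0], np.float64)))
assert out32.asnumpy().dtype == np.float32
assert out64.asnumpy().dtype == np.float64
```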


+5 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/other/gpu_convert_to_dynamic_shape_gpu_kernel.cc

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,10 @@ MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                       KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       GpuConvertToDynamicShapeGpuKernel, float)
 
+MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
+                      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      GpuConvertToDynamicShapeGpuKernel, double)
+
 MS_REG_GPU_KERNEL_ONE(GpuConvertToDynamicShape,
                       KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
                       GpuConvertToDynamicShapeGpuKernel, int8_t)
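
GpuConvertToDynamicShape (exposed to tests via mindspore.ops.operations._inner_ops) marks its input's shape as dynamic, which is how the dynamic-shape tests below force Add and Mul onto the dynamic code path; it therefore needs a float64 registration as well. A sketch mirroring the Tensoradd_d pattern used later in this commit (GPU build assumed):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner

class AddDynamic(nn.Cell):
    def __init__(self):
        super(AddDynamic, self).__init__()
        self.test_dynamic = inner.GpuConvertToDynamicShape()
        self.add = P.Add()

    def construct(self, x, y):
        # mark both inputs as dynamically shaped before adding
        return self.add(self.test_dynamic(x), self.test_dynamic(y))

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
out = AddDynamic()(Tensor(np.arange(3).astype(np.float64)),
                   Tensor(np.array([2]).astype(np.float64)))
print(out.asnumpy())  # [2. 3. 4.]
```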


+87 -33  tests/st/ops/gpu/test_tensoradd.py → tests/st/ops/gpu/test_add_op.py

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,34 +25,32 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 
-context.set_context(device_target='GPU')
-
-class TensroAdd(nn.Cell):
-    def __init__(self):
-        super(TensroAdd, self).__init__()
+class AddNet(nn.Cell):
+    def __init__(self, nptype):
+        super(AddNet, self).__init__()
 
         self.add = P.Add()
 
+        np.random.seed(0)
         self.x = Parameter(initializer(
-            Tensor(np.random.randn(2, 0).astype(np.float32)), [2, 0]), name='x')
+            Tensor(np.random.randn(2, 0).astype(nptype)), [2, 0]), name='x')
         self.y = Parameter(initializer(
-            Tensor(np.random.randn(2, 1).astype(np.float32)), [2, 1]), name='y')
+            Tensor(np.random.randn(2, 1).astype(nptype)), [2, 1]), name='y')
 
         self.x1 = Parameter(initializer(
-            Tensor(np.arange(3).reshape(3).astype(np.float32)), [3]), name='x1')
+            Tensor(np.arange(3).reshape(3).astype(nptype)), [3]), name='x1')
         self.y1 = Parameter(initializer(
-            Tensor(np.array([2]).astype(np.float32)), [1]), name='y1')
+            Tensor(np.array([2]).astype(nptype)), [1]), name='y1')
 
         self.x2 = Parameter(initializer(
-            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(np.float32)), [3, 3, 3, 3]), name='x2')
+            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(nptype)), [3, 3, 3, 3]), name='x2')
         self.y2 = Parameter(initializer(
-            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(np.float32)), [3, 3, 3, 3]), name='y2')
+            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(nptype)), [3, 3, 3, 3]), name='y2')
 
         self.x3 = Parameter(initializer(
-            Tensor(np.arange(1 * 1 * 3 * 3).reshape(1, 1, 3, 3).astype(np.float32)), [1, 1, 3, 3]), name='x3')
+            Tensor(np.arange(1 * 1 * 3 * 3).reshape(1, 1, 3, 3).astype(nptype)), [1, 1, 3, 3]), name='x3')
         self.y3 = Parameter(initializer(
-            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(np.float32)), [3, 3, 3, 3]), name='y3')
+            Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(nptype)), [3, 3, 3, 3]), name='y3')
 
     @ms_function
     def construct(self):
@@ -61,14 +59,13 @@ class TensroAdd(nn.Cell):
                 self.add(self.x3, self.y3))
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_TensorAdd():
-    add = TensroAdd()
-    output = add()
+def add(nptype):
+    context.set_context(device_target='GPU')
+
+    add_net = AddNet(nptype)
+    output = add_net()
     expect0 = np.array([])
-    expect1 = np.array([2, 3, 4])
+    expect1 = np.array([2, 3, 4]).astype(nptype)
     expect2 = np.array(
         [[[[0., 2., 4.],
            [6., 8., 10.],
@@ -96,7 +93,7 @@ def test_TensorAdd():
            [138., 140., 142.]],
          [[144., 146., 148.],
           [150., 152., 154.],
-          [156., 158., 160.]]]])
+          [156., 158., 160.]]]]).astype(nptype)
     expect3 = np.array(
         [[[[0., 2., 4.],
            [6., 8., 10.],
@@ -124,13 +121,42 @@ def test_TensorAdd():
           [75., 77., 79.]],
          [[72., 74., 76.],
           [78., 80., 82.],
-          [84., 86., 88.]]]]
-    )
+          [84., 86., 88.]]]]).astype(nptype)
     assert (output[0].asnumpy() == expect0).all()
     assert (output[1].asnumpy() == expect1).all()
     assert (output[2].asnumpy() == expect2).all()
     assert (output[3].asnumpy() == expect3).all()
 
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_float64():
+    add(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_float32():
+    add(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_float16():
+    add(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_int64():
+    add(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_int32():
+    add(np.int32)
+
 class Tensoradd_d(nn.Cell):
     def __init__(self):
         super(Tensoradd_d, self).__init__()
@@ -142,18 +168,16 @@ class Tensoradd_d(nn.Cell):
         y = self.test_dynamic(y)
         return self.add(x, y)
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_TensorAdd_dynamic():
-
+def add_dynamic(nptype):
     context.set_context(device_target='GPU', mode=context.GRAPH_MODE)
     net = Tensoradd_d()
 
-    x1 = Tensor(np.arange(3).reshape(3).astype(np.float32))
-    y1 = Tensor(np.array([2]).astype(np.float32))
+    x1 = Tensor(np.arange(3).reshape(3).astype(nptype))
+    y1 = Tensor(np.array([2]).astype(nptype))
 
-    x2 = Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(np.float32))
-    y2 = Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(np.float32))
+    x2 = Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(nptype))
+    y2 = Tensor(np.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3).astype(nptype))
 
     expect1 = np.array([2, 3, 4])
     expect2 = np.array(
@@ -189,3 +213,33 @@ def test_TensorAdd_dynamic():
     output2 = net(x2, y2)
     assert (output1.asnumpy() == expect1).all()
     assert (output2.asnumpy() == expect2).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_dynamic_float64():
+    add_dynamic(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_dynamic_float32():
+    add_dynamic(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_dynamic_float16():
+    add_dynamic(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_dynamic_int64():
+    add_dynamic(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_add_dynamic_int32():
+    add_dynamic(np.int32)
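
A note on the structure above: each dtype gets its own `test_add_*` wrapper so CI can report and filter the dtypes individually. The same coverage could be written with pytest parameterization; this is a hypothetical sketch of the more compact alternative (assuming it lives next to the `add` helper in test_add_op.py), not how the suite is actually organized:

```python
import numpy as np
import pytest

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('nptype', [np.float64, np.float32, np.float16,
                                    np.int64, np.int32])
def test_add_parametrized(nptype):
    # delegates to the add(nptype) helper defined in this file
    add(nptype)
```

One wrapper per dtype was presumably preferred so that a failing dtype shows up as a distinctly named test case.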

+47 -42  tests/st/ops/gpu/test_div_op.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,24 +29,17 @@ class NetDiv(nn.Cell):
     def construct(self, x, y):
         return self.div(x, y)
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_div():
-    x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
-    y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
-    x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
-    y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
-    x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32)
-    y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
-    x3_np = np.random.randint(1, 5, 1).astype(np.float32)
-    y3_np = np.random.randint(1, 5, 1).astype(np.float32)
-    x4_np = np.array(768).astype(np.float32)
-    y4_np = np.array(3072.5).astype(np.float32)
-    x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
-    y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
-    x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32)
-    y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32)
+def div(nptype):
+    x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype)
+    y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype)
+    x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype)
+    y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(nptype)
+    x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(nptype)
+    y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(nptype)
+    x3_np = np.random.randint(1, 5, 1).astype(nptype)
+    y3_np = np.random.randint(1, 5, 1).astype(nptype)
+    x4_np = np.array(78).astype(nptype)
+    y4_np = np.array(37.5).astype(nptype)
 
     x0 = Tensor(x0_np)
     y0 = Tensor(y0_np)
@@ -58,28 +51,24 @@ def test_div():
     y3 = Tensor(y3_np)
     x4 = Tensor(x4_np)
     y4 = Tensor(y4_np)
-    x5 = Tensor(x5_np)
-    y5 = Tensor(y5_np)
-    x6 = Tensor(x6_np)
-    y6 = Tensor(y6_np)
 
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    div = NetDiv()
-    output0 = div(x0, y0)
+    div_net = NetDiv()
+    output0 = div_net(x0, y0)
     expect0 = np.divide(x0_np, y0_np)
     diff0 = output0.asnumpy() - expect0
     error0 = np.ones(shape=expect0.shape) * 1.0e-5
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
-    output1 = div(x1, y1)
+    output1 = div_net(x1, y1)
     expect1 = np.divide(x1_np, y1_np)
     diff1 = output1.asnumpy() - expect1
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
     assert np.all(diff1 < error1)
     assert output1.shape == expect1.shape
-    output2 = div(x2, y2)
+    output2 = div_net(x2, y2)
     expect2 = np.divide(x2_np, y2_np)
     diff2 = output2.asnumpy() - expect2
     error2 = np.ones(shape=expect2.shape) * 1.0e-5
@@ -87,30 +76,46 @@ def test_div():
     assert output2.shape == expect2.shape
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    output3 = div(x3, y3)
+    output3 = div_net(x3, y3)
     expect3 = np.divide(x3_np, y3_np)
     diff3 = output3.asnumpy() - expect3
     error3 = np.ones(shape=expect3.shape) * 1.0e-5
    assert np.all(diff3 < error3)
     assert output3.shape == expect3.shape
-    output4 = div(x4, y4)
+    output4 = div_net(x4, y4)
     expect4 = np.divide(x4_np, y4_np)
     diff4 = output4.asnumpy() - expect4
     error4 = np.ones(shape=expect4.shape) * 1.0e-5
     assert np.all(diff4 < error4)
     assert output4.shape == expect4.shape
-    output5 = div(x5, y5)
-    expect5 = np.divide(x5_np, y5_np)
-    diff5 = output5.asnumpy() - expect5
-    error5 = np.ones(shape=expect5.shape) * 1.0e-5
-    assert np.all(diff5 < error5)
-    assert output5.shape == expect5.shape
-    output6 = div(x6, y6)
-    expect6 = np.divide(x6_np, y6_np)
-    diff6 = output6.asnumpy() - expect6
-    error6 = np.ones(shape=expect6.shape) * 1.0e-5
-    assert np.all(diff6 < error6)
-    assert output6.shape == expect6.shape
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_div_float64():
+    div(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_div_float32():
+    div(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_div_float16():
+    div(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_div_int64():
+    div(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_div_int32():
+    div(np.int32)
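
One quirk of the assertion pattern above: `diff < error` is one-sided, so it only fails when the GPU result exceeds the NumPy reference by more than 1e-5; a result that comes out too small would still pass. A two-sided check such as np.testing.assert_allclose would be stricter. This is a suggestion only, sketched here with a NumPy stand-in for the GPU output:

```python
import numpy as np

np.random.seed(0)
x = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float64)
y = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float64)
# stand-in for the GPU result; in the real test this is div_net(x1, y1).asnumpy()
output = x / y
np.testing.assert_allclose(output, np.divide(x, y), rtol=1.0e-5)
```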

+13 -1  tests/st/ops/gpu/test_gpu_convert_to_dynamic_shape_op.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -63,6 +63,12 @@ def gpu_convert_to_dynamic_shape_float(dtype):
     np.random.seed(0)
     finfo = np.finfo(dtype)
+
+    # np.random.uniform will overflow if we use min/max for float64, so we use
+    # the finfo for float32, but still test the operator with float64 input.
+    if dtype == np.float64:
+        finfo = np.finfo(np.float32)
+
     float_min = finfo.min
     float_max = finfo.max
     x = np.random.uniform(low=float_min, high=float_max, size=12).astype(dtype)
@@ -103,6 +109,12 @@ def test_gpu_convert_to_dynamic_shape_float16():
 def test_gpu_convert_to_dynamic_shape_float32():
     gpu_convert_to_dynamic_shape_float(np.float32)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_convert_to_dynamic_shape_float64():
+    gpu_convert_to_dynamic_shape_float(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
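
The comment added above is worth unpacking: np.random.uniform internally computes high - low, and for the full float64 range that difference overflows. A standalone NumPy sketch of both the failure and the workaround the test uses:

```python
import numpy as np

lo, hi = np.finfo(np.float64).min, np.finfo(np.float64).max
try:
    # hi - lo overflows float64; recent NumPy raises, older versions
    # could silently produce non-finite values instead
    np.random.uniform(low=lo, high=hi, size=4)
except OverflowError as err:
    print(err)  # e.g. "Range exceeds valid bounds"

# the workaround used above: draw from the float32 range, keep float64 storage
f32 = np.finfo(np.float32)
x = np.random.uniform(low=f32.min, high=f32.max, size=4).astype(np.float64)
```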


+93 -38  tests/st/ops/gpu/test_mul_op.py

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,20 +31,17 @@ class NetMul(nn.Cell):
         return self.mul(x, y)
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_mul():
-    x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    y0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    x1_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    y1_np = np.random.uniform(-2, 2, (2, 1, 4, 4)).astype(np.float32)
-    x2_np = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(np.float32)
-    y2_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    x3_np = np.random.uniform(-2, 2, 1).astype(np.float32)
-    y3_np = np.random.uniform(-2, 2, 1).astype(np.float32)
-    x4_np = np.array(768).astype(np.float32)
-    y4_np = np.array(3072.5).astype(np.float32)
+def mul(nptype):
+    x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    y0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    x1_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    y1_np = np.random.uniform(-2, 2, (2, 1, 4, 4)).astype(nptype)
+    x2_np = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(nptype)
+    y2_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    x3_np = np.random.uniform(-2, 2, 1).astype(nptype)
+    y3_np = np.random.uniform(-2, 2, 1).astype(nptype)
+    x4_np = np.array(78).astype(nptype)
+    y4_np = np.array(37.5).astype(nptype)
 
     x0 = Tensor(x0_np)
     y0 = Tensor(y0_np)
@@ -58,36 +55,36 @@ def test_mul():
     y4 = Tensor(y4_np)
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    mul = NetMul()
-    output0 = mul(x0, y0)
+    mul_net = NetMul()
+    output0 = mul_net(x0, y0)
     expect0 = np.multiply(x0_np, y0_np)
     diff0 = output0.asnumpy() - expect0
     error0 = np.ones(shape=expect0.shape) * 1.0e-5
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
 
-    output1 = mul(x1, y1)
+    output1 = mul_net(x1, y1)
     expect1 = np.multiply(x1_np, y1_np)
     diff1 = output1.asnumpy() - expect1
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
     assert np.all(diff1 < error1)
     assert output1.shape == expect1.shape
 
-    output2 = mul(x2, y2)
+    output2 = mul_net(x2, y2)
     expect2 = np.multiply(x2_np, y2_np)
     diff2 = output2.asnumpy() - expect2
     error2 = np.ones(shape=expect2.shape) * 1.0e-5
     assert np.all(diff2 < error2)
     assert output2.shape == expect2.shape
 
-    output3 = mul(x3, y3)
+    output3 = mul_net(x3, y3)
     expect3 = np.multiply(x3_np, y3_np)
     diff3 = output3.asnumpy() - expect3
     error3 = np.ones(shape=expect3.shape) * 1.0e-5
     assert np.all(diff3 < error3)
     assert output3.shape == expect3.shape
 
-    output4 = mul(x4, y4)
+    output4 = mul_net(x4, y4)
     expect4 = np.multiply(x4_np, y4_np)
     diff4 = output4.asnumpy() - expect4
     error4 = np.ones(shape=expect4.shape) * 1.0e-5
@@ -95,42 +92,72 @@ def test_mul():
     assert output4.shape == expect4.shape
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    mul = NetMul()
-    output0 = mul(x0, y0)
+    mul_net = NetMul()
+    output0 = mul_net(x0, y0)
     expect0 = np.multiply(x0_np, y0_np)
     diff0 = output0.asnumpy() - expect0
     error0 = np.ones(shape=expect0.shape) * 1.0e-5
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
 
-    output1 = mul(x1, y1)
+    output1 = mul_net(x1, y1)
     expect1 = np.multiply(x1_np, y1_np)
     diff1 = output1.asnumpy() - expect1
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
     assert np.all(diff1 < error1)
     assert output1.shape == expect1.shape
 
-    output2 = mul(x2, y2)
+    output2 = mul_net(x2, y2)
     expect2 = np.multiply(x2_np, y2_np)
     diff2 = output2.asnumpy() - expect2
     error2 = np.ones(shape=expect2.shape) * 1.0e-5
     assert np.all(diff2 < error2)
     assert output2.shape == expect2.shape
 
-    output3 = mul(x3, y3)
+    output3 = mul_net(x3, y3)
     expect3 = np.multiply(x3_np, y3_np)
     diff3 = output3.asnumpy() - expect3
     error3 = np.ones(shape=expect3.shape) * 1.0e-5
     assert np.all(diff3 < error3)
     assert output3.shape == expect3.shape
 
-    output4 = mul(x4, y4)
+    output4 = mul_net(x4, y4)
     expect4 = np.multiply(x4_np, y4_np)
     diff4 = output4.asnumpy() - expect4
     error4 = np.ones(shape=expect4.shape) * 1.0e-5
     assert np.all(diff4 < error4)
     assert output4.shape == expect4.shape
 
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_float64():
+    mul(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_float32():
+    mul(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_float16():
+    mul(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_int64():
+    mul(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_int32():
+    mul(np.int32)
+
 class NetMul_dynamic(nn.Cell):
     def __init__(self):
         super(NetMul_dynamic, self).__init__()
@@ -143,14 +170,12 @@ class NetMul_dynamic(nn.Cell):
         out = self.mul(x, y)
         return out
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_mul_dynamic():
-    x1_np = np.array([768]).astype(np.float32)
-    y1_np = np.array([3072.5]).astype(np.float32)
-    x2_np = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(np.float32)
-    y2_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-
+def mul_dynamic(nptype):
+    x1_np = np.array([78]).astype(nptype)
+    y1_np = np.array([37.5]).astype(nptype)
+    x2_np = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(nptype)
+    y2_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
 
     x1 = Tensor(x1_np)
     y1 = Tensor(y1_np)
@@ -159,10 +184,10 @@ def test_mul_dynamic():
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
-    mul = NetMul_dynamic()
+    mul_net = NetMul_dynamic()
 
-    output1 = mul(x1, y1)
-    output2 = mul(x2, y2)
+    output1 = mul_net(x1, y1)
+    output2 = mul_net(x2, y2)
     expect1 = np.multiply(x1_np, y1_np)
     expect2 = np.multiply(x2_np, y2_np)
     diff1 = output1.asnumpy() - expect1
@@ -173,3 +198,33 @@ def test_mul_dynamic():
     error2 = np.ones(shape=expect2.shape) * 1.0e-5
     assert np.all(diff2 < error2)
     assert output2.shape == expect2.shape
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_dynamic_float64():
+    mul_dynamic(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_dynamic_float32():
+    mul_dynamic(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_dynamic_float16():
+    mul_dynamic(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_dynamic_int64():
+    mul_dynamic(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_mul_dynamic_int32():
+    mul_dynamic(np.int32)
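
The static-shape mul test runs its whole battery twice, once under PYNATIVE_MODE and once under GRAPH_MODE, so both executors are checked against the same NumPy reference. The scalar pair also shrank from 768/3072.5 to 78/37.5, presumably because 768 * 3072.5 overflows float16 now that one input set drives every dtype. A minimal sketch of the dual-mode pattern (GPU build assumed; values taken from the test):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

# run the same float64 multiply under both executors, compare to NumPy
for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
    context.set_context(mode=mode, device_target="GPU")
    out = P.Mul()(Tensor(np.array(78).astype(np.float64)),
                  Tensor(np.array(37.5).astype(np.float64)))
    assert out.asnumpy() == np.multiply(78, 37.5)
```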

+54 -27  tests/st/ops/gpu/test_sub_op.py

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,20 +31,17 @@ class Net(nn.Cell):
         return self.sub(x, y)
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_Sub():
-    np_x0 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    np_y0 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    np_x1 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    np_y1 = np.random.uniform(-2, 2, (2, 1, 4, 4)).astype(np.float32)
-    np_x2 = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(np.float32)
-    np_y2 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    np_x3 = np.random.uniform(-2, 2, 1).astype(np.float32)
-    np_y3 = np.random.uniform(-2, 2, 1).astype(np.float32)
-    np_x4 = np.array(768).astype(np.float32)
-    np_y4 = np.array(3072.5).astype(np.float32)
+def sub(nptype):
+    np_x0 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    np_y0 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    np_x1 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    np_y1 = np.random.uniform(-2, 2, (2, 1, 4, 4)).astype(nptype)
+    np_x2 = np.random.uniform(-2, 2, (2, 1, 1, 4)).astype(nptype)
+    np_y2 = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    np_x3 = np.random.uniform(-2, 2, 1).astype(nptype)
+    np_y3 = np.random.uniform(-2, 2, 1).astype(nptype)
+    np_x4 = np.array(768).astype(nptype)
+    np_y4 = np.array(3072.5).astype(nptype)
     x0 = Tensor(np_x0)
     y0 = Tensor(np_y0)
     x1 = Tensor(np_x1)
@@ -68,12 +65,12 @@ def test_Sub():
     error4 = np.ones(shape=expect4.shape) * 1.0e-5
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    sub = Net()
-    output0 = sub(x0, y0)
-    output1 = sub(x1, y1)
-    output2 = sub(x2, y2)
-    output3 = sub(x3, y3)
-    output4 = sub(x4, y4)
+    sub_net = Net()
+    output0 = sub_net(x0, y0)
+    output1 = sub_net(x1, y1)
+    output2 = sub_net(x2, y2)
+    output3 = sub_net(x3, y3)
+    output4 = sub_net(x4, y4)
     diff0 = output0.asnumpy() - expect0
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
@@ -91,12 +88,12 @@ def test_Sub():
     assert output4.shape == expect4.shape
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    sub = Net()
-    output0 = sub(x0, y0)
-    output1 = sub(x1, y1)
-    output2 = sub(x2, y2)
-    output3 = sub(x3, y3)
-    output4 = sub(x4, y4)
+    sub_net = Net()
+    output0 = sub_net(x0, y0)
+    output1 = sub_net(x1, y1)
+    output2 = sub_net(x2, y2)
+    output3 = sub_net(x3, y3)
+    output4 = sub_net(x4, y4)
     diff0 = output0.asnumpy() - expect0
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
@@ -112,3 +109,33 @@ def test_Sub():
     diff4 = output4.asnumpy() - expect4
     assert np.all(diff4 < error4)
     assert output4.shape == expect4.shape
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sub_float64():
+    sub(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sub_float32():
+    sub(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sub_float16():
+    sub(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sub_int64():
+    sub(np.int64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sub_int32():
+    sub(np.int32)
