Browse Source

!8283 Minimum Op and Mul Op support dynamic shape

From: @jonwe
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
71af3bf1ac
10 changed files with 146 additions and 92 deletions
  1. +59
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
  2. +3
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
  3. +11
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
  4. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  5. +2
    -0
      mindspore/core/abstract/infer_functions.h
  6. +22
    -78
      mindspore/core/abstract/prim_maths.cc
  7. +1
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  8. +2
    -1
      mindspore/core/base/core_ops.h
  9. +6
    -1
      mindspore/ops/operations/math_ops.py
  10. +39
    -0
      tests/st/ops/gpu/test_broadcast_op.py

+ 59
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu View File

@@ -89,6 +89,48 @@ struct AddFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
};

// DivNoNan: elementwise division that yields 0 wherever the divisor is
// (numerically) zero, instead of producing Inf/NaN. The generic version
// guards |rhs| < kFloatEplison (declared in broadcast_impl.cuh); the
// specializations below handle integer and half-precision inputs.
template <typename T>
struct DivNoNanFunc {
  // Generic float/double path: values with |rhs| below kFloatEplison are
  // treated as zero.
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return rhs < kFloatEplison && rhs > -kFloatEplison ? static_cast<T>(0.0) : (lhs / rhs);
  }
};

template <>
struct DivNoNanFunc<int> {
  // Integers compare exactly against zero — no epsilon needed.
  __device__ __host__ __forceinline__ int operator()(const int &lhs, const int &rhs) {
    return rhs == 0 ? 0 : (lhs / rhs);
  }
};

template <>
struct DivNoNanFunc<half> {
  // Compute in float to avoid fp16 accuracy loss. The 7e-5 threshold sits
  // just above the smallest normal half (~6.1e-5), so any divisor that
  // rounds to (sub)normal-zero territory counts as zero.
  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
    float r = __half2float(rhs);
    if (r < 0.00007f && r > -0.00007f) {
      return static_cast<half>(0.0);
    }
    return __float2half_rn(__half2float(lhs) / r);
  }
};

template <>
struct DivNoNanFunc<half2> {
  // Lanes are independent elements: only the lane whose divisor is near
  // zero is zeroed; the other lane still produces lhs / rhs. (Previously a
  // near-zero divisor in either lane zeroed both lanes, which diverges from
  // DivNoNan's elementwise semantics.)
  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
    float2 l = __half22float2(lhs);
    float2 r = __half22float2(rhs);
    l.x = (r.x < kFloatEplison && r.x > -kFloatEplison) ? 0.0f : (l.x / r.x);
    l.y = (r.y < kFloatEplison && r.y > -kFloatEplison) ? 0.0f : (l.y / r.y);
    return __float22half2_rn(l);
  }
};

// convert to float to fix accuracy issue
template <typename T>
struct FloorDivFunc {
@@ -189,6 +231,8 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
return ElewiseArithKernel<T, AbsGradFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_DIV:
return ElewiseArithKernel<T, DivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_DIVNONAN:
return ElewiseArithKernel<T, DivNoNanFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default:
break;
}
@@ -222,11 +266,10 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }

template <typename T, typename Func>
__global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
__global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t l5, const size_t l6, const size_t r0, const size_t r1, const size_t r2,
const size_t r3, const size_t r4, const size_t r5, const size_t r6, const size_t d0,
const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *x0, const T *x1, bool *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
@@ -258,8 +301,8 @@ __global__ void BroadcastCmpKernel(const size_t l0, const size_t l1, const size_

template <typename T>
void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
const T *x1, bool *y, cudaStream_t stream) {
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, bool *y,
cudaStream_t stream) {
size_t size = 1;
for (auto d : y_dims) {
size *= d;
@@ -329,8 +372,8 @@ __global__ void BroadcastArithKernel(const size_t l0, const size_t l1, const siz

template <typename T>
void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0,
const T *x1, T *y, cudaStream_t stream) {
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const T *x0, const T *x1, T *y,
cudaStream_t stream) {
size_t size = 1;
for (auto d : y_dims) {
size *= d;
@@ -386,6 +429,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_DIVNONAN:
return BroadcastArithKernel<T, DivNoNanFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default:
break;
}
@@ -419,8 +467,8 @@ __global__ void BroadcastToKernel(const size_t i0, const size_t i1, const size_t

template <typename T>
void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3, const size_t &o0,
const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr,
T *output_addr, cudaStream_t stream) {
const size_t &o1, const size_t &o2, const size_t &o3, const T *input_addr, T *output_addr,
cudaStream_t stream) {
size_t nums = o0 * o1 * o2 * o3;
BroadcastToKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr,
output_addr);


+ 3
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh View File

@@ -20,6 +20,8 @@
#include <vector>
#include "runtime/device/gpu/cuda_common.h"

const float kFloatEplison = 1e-37;

enum BroadcastOpType {
BROADCAST_TYPE_GREATER = 0,
BROADCAST_TYPE_LESS = 1,
@@ -33,6 +35,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_FLOORDIV = 9,
BROADCAST_TYPE_ABSGRAD = 10,
BROADCAST_TYPE_DIV = 11,
BROADCAST_TYPE_DIVNONAN = 12,
BROADCAST_TYPE_INVALID = 0xffffffff,
};



+ 11
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc View File

@@ -62,6 +62,10 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
DivNoNan,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)

// fp16
MS_REG_GPU_KERNEL_ONE(
@@ -107,6 +111,10 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
DivNoNan,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)

// int32
MS_REG_GPU_KERNEL_ONE(
@@ -136,5 +144,8 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)
} // namespace kernel
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h View File

@@ -145,7 +145,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV},
{"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN},
};

iter = kBroadcastArithmetricTypeMap.find(kernel_name);


+ 2
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -236,6 +236,8 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>


+ 22
- 78
mindspore/core/abstract/prim_maths.cc View File

@@ -47,27 +47,6 @@ AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &pri
return inp->Clone()->Broaden();
}

AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
ShapePtr shape_x = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_x);
std::vector<int64_t> x_dims = shape_x->shape();
ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[1]->GetShapeTrack());
MS_EXCEPTION_IF_NULL(shape_y);
std::vector<int64_t> y_dims = shape_y->shape();
auto broadcast_shape = BroadcastShape(x_dims, y_dims);
if (broadcast_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}
auto out = args_spec_list[0]->Broaden();
out->set_shape(std::make_shared<Shape>(broadcast_shape));
return out;
}

AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors.
@@ -97,57 +76,6 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p
return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());

auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(input_y->shape());

auto x_shape = input_x->shape()->shape();
auto y_shape = input_y->shape()->shape();
auto output_shape = BroadcastShape(x_shape, y_shape);
if (output_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}

auto x_type = input_x->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
auto y_type = input_y->BuildType();
MS_EXCEPTION_IF_NULL(y_type);
MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());

auto x_element = x_type->cast<TensorTypePtr>()->element();
MS_EXCEPTION_IF_NULL(x_element);
auto y_element = y_type->cast<TensorTypePtr>()->element();
MS_EXCEPTION_IF_NULL(y_element);

auto x_element_type = x_element->number_type();
auto y_element_type = y_element->number_type();

auto x_priority = type_priority_map.find(x_element_type);
if (x_priority == type_priority_map.end()) {
MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type.";
}
auto y_priority = type_priority_map.find(y_element_type);
if (y_priority == type_priority_map.end()) {
MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type.";
}

if (x_priority->second >= y_priority->second) {
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
} else {
return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
}
}

AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
@@ -173,8 +101,8 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
return ret;
}

AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
@@ -188,10 +116,6 @@ AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &
auto x_shape = input_x->shape()->shape();
auto y_shape = input_y->shape()->shape();
auto output_shape = BroadcastShape(x_shape, y_shape);
if (output_shape.empty()) {
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
<< args_spec_list[1]->ToString();
}

auto x_type = input_x->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
@@ -223,5 +147,25 @@ AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &
return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
}
}

// Minimum(x, y): shape/dtype inference is identical to the other broadcast
// binary ops, so delegate to the shared InferImplBinaryBase.
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine, const PrimitivePtr &prim,
                                 const AbstractBasePtrList &args) {
  return InferImplBinaryBase(engine, prim, args);
}

// Mul(x, y): delegates to the shared broadcast binary-op inference.
AbstractBasePtr InferImplMul(const AnalysisEnginePtr &engine, const PrimitivePtr &prim,
                             const AbstractBasePtrList &args) {
  return InferImplBinaryBase(engine, prim, args);
}

// Sub(x, y): delegates to the shared broadcast binary-op inference.
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &engine, const PrimitivePtr &prim,
                             const AbstractBasePtrList &args) {
  return InferImplBinaryBase(engine, prim, args);
}

// DivNoNan(x, y): delegates to the shared broadcast binary-op inference.
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine, const PrimitivePtr &prim,
                                  const AbstractBasePtrList &args) {
  return InferImplBinaryBase(engine, prim, args);
}
} // namespace abstract
} // namespace mindspore

+ 1
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -44,6 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
// Array
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},


+ 2
- 1
mindspore/core/base/core_ops.h View File

@@ -114,7 +114,6 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam
inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");

// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -215,6 +214,8 @@ inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMi
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");
inline const PrimitivePtr kPrimDivNoNan = std::make_shared<Primitive>("DivNoNan");
inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");


+ 6
- 1
mindspore/ops/operations/math_ops.py View File

@@ -1827,7 +1827,7 @@ class Div(_MathBinaryOp):
return None


class DivNoNan(_MathBinaryOp):
class DivNoNan(PrimitiveWithCheck):
"""
Computes a safe divide which returns 0 if the y is zero.

@@ -1856,6 +1856,11 @@ class DivNoNan(_MathBinaryOp):
[0., 0., 0., 2.5, 2.0]
"""

@prim_attr_register
def __init__(self):
    """Initialize DivNoNan."""
    # Register the primitive's input/output names for the GPU kernel binding.
    self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()


+ 39
- 0
tests/st/ops/gpu/test_broadcast_op.py View File

@@ -71,6 +71,13 @@ def test_nobroadcast():
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

x2_np_zero = np.zeros_like(x2_np)
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
assert np.allclose(output_ms.asnumpy(), x2_np_zero)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -114,6 +121,14 @@ def test_nobroadcast_fp16():
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

x2_np_zero = np.zeros_like(x2_np)
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
assert np.allclose(output_ms.asnumpy(), x2_np_zero)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -165,6 +180,14 @@ def test_broadcast():
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

x2_np_zero = np.zeros_like(x2_np)
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
assert np.allclose(output_ms.asnumpy(), x2_np_zero)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -216,6 +239,14 @@ def test_broadcast_diff_dims():
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

x2_np_zero = np.zeros_like(x2_np)
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
assert np.allclose(output_ms.asnumpy(), x2_np_zero)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -258,3 +289,11 @@ def test_broadcast_fp16():
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)

x2_np_zero = np.zeros_like(x2_np)
output_ms = P.DivNoNan()(Tensor(x1_np), Tensor(x2_np_zero))
assert np.allclose(output_ms.asnumpy(), x2_np_zero)

Loading…
Cancel
Save