
Add FloorMod op for CPU and GPU

pull/15554/head
xcnick 4 years ago
parent
commit
181844a3d9
10 changed files with 226 additions and 6 deletions
  1. +19 -1 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
  2. +16 -0 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
  3. +1 -0 mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  4. +41 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
  5. +1 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
  6. +18 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
  7. +5 -4 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  8. +1 -1 mindspore/ops/operations/math_ops.py
  9. +104 -0 tests/st/ops/cpu/test_arithmetic_op.py
  10. +20 -0 tests/st/ops/gpu/test_broadcast_op.py

+19 -1 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc

@@ -169,6 +169,21 @@ void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t s
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::FloorMod(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto x = static_cast<double>(input1[idx[0]]);
auto y = static_cast<double>(input2[idx[1]]);
auto res = x - floor(x / y) * y;
out[i] = static_cast<T>((std::abs(res) > 1e-9) && ((res < 0.0) != (y < 0.0)) ? res + y : res);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
@@ -323,7 +338,8 @@ static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = {
{prim::kPrimAtan2->name(), ATAN2},
{prim::kPrimRealDiv->name(), REALDIV},
{prim::kPrimEqual->name(), EQUAL},
{prim::kPrimSquaredDifference->name(), SQUAREDDIFFERENCE}};
{prim::kPrimSquaredDifference->name(), SQUAREDDIFFERENCE},
{prim::kPrimFloorMod->name(), FLOORMOD}};

void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
@@ -471,6 +487,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
FloorDiv<T>(input1, input2, output, lens);
} else if (operate_type_ == MOD) {
Mod<T>(input1, input2, output, lens);
} else if (operate_type_ == FLOORMOD) {
FloorMod<T>(input1, input2, output, lens);
} else if (operate_type_ == POW) {
Pow<T>(input1, input2, output, lens);
} else if (operate_type_ == ASSIGNADD) {
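For reference, the kernel above implements floored-division modulo: res = x - floor(x / y) * y, with a small-epsilon sign correction so the remainder carries the divisor's sign. A minimal NumPy sketch of the same formula (the helper name and epsilon are illustrative, not part of this commit); np.mod uses the same semantics:

import numpy as np

def floor_mod(x, y, eps=1e-9):
    # Same formula as the CPU kernel: x - floor(x / y) * y,
    # then add y where the remainder's sign disagrees with the divisor's.
    res = x - np.floor(x / y) * y
    fix = (np.abs(res) > eps) & ((res < 0) != (y < 0))
    return np.where(fix, res + y, res)

x = np.array([7.0, -7.0, 7.0, -7.0])
y = np.array([3.0, 3.0, -3.0, -3.0])
print(floor_mod(x, y))  # [ 1.  2. -2. -1.]
print(np.mod(x, y))     # identical: the result takes the divisor's sign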


+16 -0 mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h

@@ -54,6 +54,8 @@ class ArithmeticCPUKernel : public CPUKernel {
template <typename T>
void Mod(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void FloorMod(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Pow(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void AssignAdd(T *input1, const T *input2, T *out, size_t size);
@@ -144,6 +146,20 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);


+1 -0 mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h

@@ -70,6 +70,7 @@ enum OperateType {
REALDIV,
FLOORDIV,
MOD,
FLOORMOD,
NEG,
LESS,
ASSIGNADD,


+41 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu

@@ -220,6 +220,40 @@ struct ModFunc<half2> {
}
};

template <typename T>
struct FloorModFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
T res = lhs - floorf(lhs / rhs) * rhs;
res = (std::abs(res) > 1e-9) && ((res < 0.0) != (rhs < 0.0)) ? res + rhs : res;
return res;
}
};

template <>
struct FloorModFunc<half> {
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
float l = __half2float(lhs);
float r = __half2float(rhs);
float res = l - floorf(l / r) * r;
res = (std::abs(res) > 1e-9) && ((res < 0.0) != (r < 0.0)) ? res + r : res;
return __float2half_rn(res);
}
};

template <>
struct FloorModFunc<half2> {
__device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
float2 l = __half22float2(lhs);
float2 r = __half22float2(rhs);
float2 res;
res.x = l.x - floorf(l.x / r.x) * r.x;
res.y = l.y - floorf(l.y / r.y) * r.y;
res.x = (std::abs(res.x) > 1e-9) && ((res.x < 0.0) != (r.x < 0.0)) ? res.x + r.x : res.x;
res.y = (std::abs(res.y) > 1e-9) && ((res.y < 0.0) != (r.y < 0.0)) ? res.y + r.y : res.y;
return __float22half2_rn(res);
}
};

template <typename T>
struct AbsGradFunc {
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
@@ -318,6 +352,8 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
return ElewiseArithKernel<T, SquaredDifferenceFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_MOD:
return ElewiseArithKernel<T, ModFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_FLOORMOD:
return ElewiseArithKernel<T, FloorModFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default:
break;
}
@@ -554,6 +590,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_FLOORMOD:
return BroadcastArithKernel<T, FloorModFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default:
break;
}
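The half and half2 specializations above widen to float, apply the same floored-mod formula, and round back with __float2half_rn / __float22half2_rn. A NumPy sketch of that compute-in-float32 pattern for float16 inputs (illustrative only, not part of this commit):

import numpy as np

x = np.array([7.0, -7.0], dtype=np.float16)
y = np.array([3.0, -3.0], dtype=np.float16)

# Widen to float32, apply x - floor(x / y) * y, round back to float16,
# mirroring the __half2float / __float2half_rn conversions in the CUDA functor.
xf, yf = x.astype(np.float32), y.astype(np.float32)
print((xf - np.floor(xf / yf) * yf).astype(np.float16))  # [ 1. -1.]
print(np.mod(x, y))                                      # same values in float16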


+1 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh

@@ -39,6 +39,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_EQUAL = 13,
BROADCAST_TYPE_SQUARED_DIFFERENCE = 14,
BROADCAST_TYPE_MOD = 15,
BROADCAST_TYPE_FLOORMOD = 16,
BROADCAST_TYPE_INVALID = 0xffffffff,
};



+18 -0 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc

@@ -56,6 +56,10 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
BroadcastOpGpuKernel, double)

// fp32
MS_REG_GPU_KERNEL_ONE(
@@ -110,6 +114,10 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float)

// fp16
MS_REG_GPU_KERNEL_ONE(
@@ -164,6 +172,10 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half)

// int32
MS_REG_GPU_KERNEL_ONE(
@@ -205,6 +217,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int)

// int64
MS_REG_GPU_KERNEL_ONE(
@@ -246,6 +261,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
BroadcastOpGpuKernel, int64_t)

// int8
MS_REG_GPU_KERNEL_ONE(


+5 -4 mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h

@@ -143,10 +143,11 @@ class BroadcastOpGpuKernel : public GpuKernel {
}

static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, {"Mod", BROADCAST_TYPE_MOD},
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"Add", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV}, {"DivNoNan", BROADCAST_TYPE_DIVNONAN}, {"Mod", BROADCAST_TYPE_MOD},
{"FloorMod", BROADCAST_TYPE_FLOORMOD},
};

iter = kBroadcastArithmetricTypeMap.find(kernel_name);


+1 -1 mindspore/ops/operations/math_ops.py

@@ -2501,7 +2501,7 @@ class FloorMod(_MathBinaryOp):
TypeError: If neither `input_x` nor `input_y` is a Tensor.

Supported Platforms:
``Ascend``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> input_x = Tensor(np.array([2, 4, -1]), mindspore.int32)
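Beyond the docstring excerpt above, a minimal end-to-end sketch of the newly supported backends (assumes a CPU build of MindSpore; device_target="GPU" works the same on a GPU build):

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

x = Tensor(np.array([2, 4, -1]), mindspore.int32)
y = Tensor(np.array([3, 3, 3]), mindspore.int32)
print(P.FloorMod()(x, y))  # expected: [2 1 2], matching np.mod([2, 4, -1], 3)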


+104 -0 tests/st/ops/cpu/test_arithmetic_op.py

@@ -60,6 +60,15 @@ class ModNet(nn.Cell):
return self.mod(x, y)


class FloorModNet(nn.Cell):
def __init__(self):
super(FloorModNet, self).__init__()
self.floor_mod = P.FloorMod()

def construct(self, x, y):
return self.floor_mod(x, y)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -324,7 +333,102 @@ def test_mod():
assert output6.shape == expect6.shape


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_floor_mod():
prop = 1 if np.random.random() < 0.5 else -1
x0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop
y0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop
x1_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop
y1_np = np.random.randint(1, 100, (2, 1, 4, 4)).astype(np.float32) * prop
x2_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.float16) * prop
y2_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float16) * prop
x3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop
y3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop
x4_np = np.array(768).astype(np.float32) * prop
y4_np = np.array(3072.5).astype(np.float32) * prop
x5_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int32) * prop
y5_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop
x6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop
y6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop
x7_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int64) * prop
y7_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int64) * prop

x0 = Tensor(x0_np)
y0 = Tensor(y0_np)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
x2 = Tensor(x2_np)
y2 = Tensor(y2_np)
x3 = Tensor(x3_np)
y3 = Tensor(y3_np)
x4 = Tensor(x4_np)
y4 = Tensor(y4_np)
x5 = Tensor(x5_np)
y5 = Tensor(y5_np)
x6 = Tensor(x6_np)
y6 = Tensor(y6_np)
x7 = Tensor(x7_np)
y7 = Tensor(y7_np)

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
floor_mod = FloorModNet()
output0 = floor_mod(x0, y0)
expect0 = np.mod(x0_np, y0_np)
diff0 = output0.asnumpy() - expect0
error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0)
assert output0.shape == expect0.shape

output1 = floor_mod(x1, y1)
expect1 = np.mod(x1_np, y1_np)
diff1 = output1.asnumpy() - expect1
error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1)
assert output1.shape == expect1.shape

output2 = floor_mod(x2, y2)
expect2 = np.mod(x2_np, y2_np).astype(np.float16)
diff2 = output2.asnumpy() - expect2
error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2)
assert output2.shape == expect2.shape

output3 = floor_mod(x3, y3)
expect3 = np.mod(x3_np, y3_np)
diff3 = output3.asnumpy() - expect3
error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3)
assert output3.shape == expect3.shape

output4 = floor_mod(x4, y4)
expect4 = np.mod(x4_np, y4_np)
diff4 = output4.asnumpy() - expect4
error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4)
assert output4.shape == expect4.shape

output5 = floor_mod(x5, y5)
expect5 = np.mod(x5_np, y5_np)
assert np.all(output5.asnumpy() == expect5)
assert output5.shape == expect5.shape

output6 = floor_mod(x6, y6)
expect6 = np.mod(x6_np, y6_np)
diff6 = output6.asnumpy() - expect6
error6 = np.ones(shape=expect6.shape) * 1.0e-5
assert np.all(diff6 < error6)
assert output6.shape == expect6.shape

output7 = floor_mod(x7, y7)
expect7 = np.mod(x7_np, y7_np).astype(np.int64)
assert np.all(output7.asnumpy() == expect7)
assert output7.shape == expect7.shape


test_sub()
test_div()
test_floor_div()
test_mod()
test_floor_mod()

+20 -0 tests/st/ops/gpu/test_broadcast_op.py

@@ -83,6 +83,10 @@ def test_nobroadcast():
output_np = np.fmod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
output_np = np.mod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
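Note the reference functions: the existing Mod checks compare against np.fmod (truncated division, remainder follows the dividend's sign), while the new FloorMod checks use np.mod (floored division, remainder follows the divisor's sign). A quick sketch of the difference:

import numpy as np

x = np.array([7.0, -7.0])
y = np.array([-3.0, 3.0])
print(np.fmod(x, y))  # [ 1. -1.]  truncated: sign of the dividend
print(np.mod(x, y))   # [-2.  2.]  floored:  sign of the divisor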


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -138,6 +142,10 @@ def test_nobroadcast_fp16():
output_np = np.fmod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
output_np = np.mod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -201,6 +209,10 @@ def test_broadcast():
output_np = np.fmod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
output_np = np.mod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -264,6 +276,10 @@ def test_broadcast_diff_dims():
output_np = np.fmod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
output_np = np.mod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -319,6 +335,10 @@ def test_broadcast_fp16():
output_np = np.fmod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)

output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
output_np = np.mod(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training

