diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
index 8af583bc80..05f14f5ccb 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -109,6 +109,16 @@ void Rint(const T *in, T *out, size_t size) {
   CPUKernelUtils::ParallelFor(task, size);
 }
 
+template <typename T>
+void Round(const T *in, T *out, size_t size) {
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      out[i] = static_cast<T>(nearbyint(in[i]));
+    }
+  };
+  CPUKernelUtils::ParallelFor(task, size);
+}
+
 template <typename T>
 void Reciprocal(const T *in, T *out, size_t size) {
   auto task = [&](size_t start, size_t end) {
@@ -251,6 +261,7 @@ static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::k
                                                                         {prim::kPrimSign->name(), SIGN},
                                                                         {prim::kPrimFloor->name(), FLOOR},
                                                                         {prim::kPrimRint->name(), RINT},
+                                                                        {prim::kPrimRound->name(), ROUND},
                                                                         {prim::kPrimReciprocal->name(), RECIPROCAL},
                                                                         {prim::kPrimGeLU->name(), GELU},
                                                                         {prim::kPrimAsin->name(), ASIN},
@@ -317,7 +328,7 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     {ATAN, Atan<T>},   {SINH, Sinh<T>},
     {COSH, Cosh<T>},   {ASINH, Asinh<T>},
     {ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
-    {RINT, Rint<T>}};
+    {RINT, Rint<T>},   {ROUND, Round<T>}};
   if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
     kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
   } else {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
index 2e47dda977..d46f9336e8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -67,6 +67,8 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index da8b528d1e..0508fe2c1a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -114,6 +114,7 @@ enum OperateType {
   ACOSHGRAD,
   ATAN2,
   RINT,
+  ROUND,
 };
 
 class CPUKernel : public kernel::KernelMod {
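Note on rounding semantics: the CPU kernel above implements Round with nearbyint rather than round, so under the default FE_TONEAREST rounding mode ties go to the nearest even value (1.5 -> 2.0, 2.5 -> 2.0, -2.5 -> -2.0). That matches numpy.round, which the new tests further down use as the reference. A quick standalone check, not part of the patch:

    import numpy as np

    ties = np.array([0.5, 1.5, 2.5, -2.5, 4.5])
    # numpy rounds halves to even, matching C's nearbyint under FE_TONEAREST;
    # C's round() would instead give [1., 2., 3., -3., 5.] (half away from zero).
    print(np.round(ties))  # [ 0.  2.  2. -2.  4.]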
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
index d3647c6f24..b08983a320 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
@@ -17,6 +17,13 @@
 #include "unary_op_impl.cuh"
 template <typename T>
 __global__ void ExponentialKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = expf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void ExponentialKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = exp(input[i]);
   }
@@ -32,7 +39,14 @@ __global__ void ExponentialKernel(const half *input, half *output, const size_t
 template <typename T>
 __global__ void Expm1Kernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(expm1f(static_cast<float>(input[i])));
+    output[i] = expm1f(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void Expm1Kernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = expm1(input[i]);
   }
   return;
 }
@@ -44,6 +58,13 @@ __global__ void LogarithmKernel(const T *input, T *output, const size_t count) {
   return;
 }
 template <>
+__global__ void LogarithmKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = log(input[i]);
+  }
+  return;
+}
+template <>
 __global__ void LogarithmKernel(const half *input, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = hlog(input[i]);
@@ -53,21 +74,42 @@ __global__ void LogarithmKernel(const half *input, half *output, const size_t co
 template <typename T>
 __global__ void Log1pKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(log1pf(static_cast<float>(input[i])));
+    output[i] = log1pf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void Log1pKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = log1p(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void ErfKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erff(static_cast<float>(input[i])));
+    output[i] = erff(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void ErfKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = erf(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void ErfcKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    output[i] = static_cast<T>(erfcf(static_cast<float>(input[i])));
+    output[i] = erfcf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void ErfcKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = erfc(input[i]);
   }
   return;
 }
@@ -96,6 +138,13 @@ __global__ void SquareKernel(const T *input, T *output, const size_t count) {
 }
 template <typename T>
 __global__ void SqrtKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = sqrtf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void SqrtKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = sqrt(input[i]);
   }
@@ -110,6 +159,13 @@ __global__ void SqrtKernel(const half *input, half *output, const size_t count)
 }
 template <typename T>
 __global__ void RsqrtKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = rsqrtf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void RsqrtKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = rsqrt(input[i]);
   }
@@ -124,6 +180,13 @@ __global__ void RsqrtKernel(const half *input, half *output, const size_t count)
 }
 template <typename T>
 __global__ void SinKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = sinf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void SinKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = sin(input[i]);
   }
@@ -139,23 +202,40 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) {
 template <typename T>
 __global__ void AsinKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    float inputf = static_cast<float>(input[i]);
-    T res = static_cast<T>(asinf(inputf));
-    output[i] = res;
+    output[i] = asinf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void AsinKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = asin(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void AsinhKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    float inputf = static_cast<float>(input[i]);
-    T res = static_cast<T>(asinhf(inputf));
-    output[i] = res;
+    output[i] = asinhf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void AsinhKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = asinh(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void CosKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = cosf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void CosKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = cos(input[i]);
   }
@@ -171,27 +251,42 @@ __global__ void CosKernel(const half *input, half *output, const size_t count) {
 template <typename T>
 __global__ void ACosKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    float inputf = static_cast<float>(input[i]);
-    T res = static_cast<T>(acosf(inputf));
-    output[i] = res;
+    output[i] = acosf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void ACosKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = acos(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void AcoshKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    float inputf = static_cast<float>(input[i]);
-    T res = static_cast<T>(acoshf(inputf));
-    output[i] = res;
+    output[i] = acoshf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void AcoshKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = acosh(input[i]);
   }
   return;
 }
 template <typename T>
 __global__ void AtanKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
-    float inputf = static_cast<float>(input[i]);
-    T res = static_cast<T>(atanf(inputf));
-    output[i] = res;
+    output[i] = atanf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void AtanKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = atan(input[i]);
   }
   return;
 }
@@ -212,6 +307,13 @@ __global__ void AbsKernel(const half *input, half *output, const size_t count) {
 }
 template <typename T>
 __global__ void FloorKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = floorf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void FloorKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = floor(input[i]);
   }
@@ -226,6 +328,13 @@ __global__ void FloorKernel(const half *input, half *output, const size_t count)
 }
 template <typename T>
 __global__ void RintKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = rintf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void RintKernel(const double *input, double *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = rint(input[i]);
   }
@@ -239,6 +348,20 @@ __global__ void RintKernel(const half *input, half *output, const size_t count)
   return;
 }
 template <typename T>
+__global__ void RoundKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = nearbyintf(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void RoundKernel(const double *input, double *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = nearbyint(input[i]);
+  }
+  return;
+}
+template <typename T>
 void Exponential(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
@@ -348,6 +471,11 @@ void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
   RintKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
+template <typename T>
+void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  RoundKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 
 // double
 template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
@@ -371,6 +499,7 @@ template void Rsqrt<double>(const double *input, double *output, const size_t co
 template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Round<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
 
 // float
@@ -395,6 +524,7 @@ template void Rsqrt<float>(const float *input, float *output, const size_t count
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Round<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 
 // half
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
@@ -418,3 +548,28 @@ template void Rsqrt<half>(const half *input, half *output, const size_t count, c
 template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Rint<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Round<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+
+// int32
+template void Exponential<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Logarithm<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Log1p<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Erf<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Erfc<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Reciprocal<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Square<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Sqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Sin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Cos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Asin<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void ACos<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Atan<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Asinh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Acosh<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Rsqrt<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Abs<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Floor<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Rint<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
+template void Round<int>(const int *input, int *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
index 4f0d92d81d..c828496f81 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
@@ -60,5 +60,7 @@ template <typename T>
 void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Rint(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Round(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
index 73a4ea9372..6b0897e02b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
@@ -114,5 +114,13 @@ MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      UnaryOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
+MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Round, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
index 3d80f116d4..9a73d7947d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
@@ -49,6 +49,7 @@ enum UnaryOptype {
   UNARY_OP_ABS,
   UNARY_OP_FLOOR,
   UNARY_OP_RINT,
+  UNARY_OP_ROUND,
   UNARY_OP_INVALID_TYPE = 255
 };
 
@@ -63,7 +64,7 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {
   {"ACos", UNARY_OP_ACOS},   {"Atan", UNARY_OP_ATAN},
   {"Asinh", UNARY_OP_ASINH}, {"Acosh", UNARY_OP_ACOSH},
   {"Abs", UNARY_OP_ABS},     {"Floor", UNARY_OP_FLOOR},
-  {"Rint", UNARY_OP_RINT}};
+  {"Rint", UNARY_OP_RINT},   {"Round", UNARY_OP_ROUND}};
 
 template <typename T>
 class UnaryOpGpuKernel : public GpuKernel {
@@ -165,6 +166,10 @@ class UnaryOpGpuKernel : public GpuKernel {
       case UNARY_OP_RINT: {
         Rint(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_ROUND: {
+        Round(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported.";
       }
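The registrations above also expose Round on GPU for Int32, where the generic RoundKernel round-trips each element through float (nearbyintf), leaving integer inputs effectively unchanged. A hedged usage sketch, assuming a GPU build that includes this patch:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor, ops

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Integer input passes through unchanged (int -> float -> nearbyintf -> int).
    print(ops.Round()(Tensor(np.array([1, -2, 3], np.int32))))  # expect [ 1 -2  3]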
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 49cd9d6b9b..c9b3ab60e3 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -3919,7 +3919,7 @@ class Round(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([0.8, 1.5, 2.3, 2.5, -4.5]), mindspore.float32)
diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py
index a7a00f5e9a..2f5bdb78fd 100644
--- a/tests/st/ops/cpu/test_arithmetic_self_op.py
+++ b/tests/st/ops/cpu/test_arithmetic_self_op.py
@@ -41,6 +41,15 @@ class FloorNet(nn.Cell):
         return self.floor(x)
 
 
+class RoundNet(nn.Cell):
+    def __init__(self):
+        super(RoundNet, self).__init__()
+        self.round = P.Round()
+
+    def construct(self, x):
+        return self.round(x)
+
+
 class ReciprocalNet(nn.Cell):
     def __init__(self):
         super(ReciprocalNet, self).__init__()
@@ -144,6 +153,20 @@ def test_rint():
     np.testing.assert_almost_equal(output.asnumpy(), expect_output)
 
 
+def test_round():
+    net = RoundNet()
+
+    x = np.array([0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]).astype(np.float16)
+    output = net(Tensor(x))
+    expect_output = np.round(x).astype(np.float16)
+    np.testing.assert_almost_equal(output.asnumpy(), expect_output)
+
+    x = np.array([0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]).astype(np.float32)
+    output = net(Tensor(x))
+    expect_output = np.round(x).astype(np.float32)
+    np.testing.assert_almost_equal(output.asnumpy(), expect_output)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
diff --git a/tests/st/ops/gpu/test_round_op.py b/tests/st/ops/gpu/test_round_op.py
new file mode 100644
index 0000000000..b25cdd9650
--- /dev/null
+++ b/tests/st/ops/gpu/test_round_op.py
@@ -0,0 +1,60 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, ops
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.round = ops.Round()
+
+    def construct(self, x):
+        return self.round(x)
+
+
+def generate_testcases(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x = np.array([0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.round(x).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    x = np.array([0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5]).astype(nptype)
+    net = Net()
+    output = net(Tensor(x))
+    expect = np.round(x).astype(nptype)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_round_float32():
+    generate_testcases(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_round_float16():
+    generate_testcases(np.float16)
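For reference, a minimal end-to-end sketch of the op on the newly supported CPU backend, reusing the input values from the tests above (assumes a build containing this patch):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor, ops

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    x = Tensor(np.array([0.9920, -0.4077, 0.9734, -1.0362, 1.5, -2.5, 4.5], np.float32))
    # Ties round to even: 1.5 -> 2.0, -2.5 -> -2.0, 4.5 -> 4.0.
    print(ops.Round()(x))  # expect [ 1. -0.  1. -1.  2. -2.  4.]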