| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -343,6 +343,31 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre | |||
| return; | |||
| } | |||
| // double | |||
| template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Expm1<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Logarithm<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Log1p<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Erf<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Erfc<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Negative<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Reciprocal<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Square<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Sqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Sin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Cos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void ACos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Atan<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Asinh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Acosh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Zeroslike<double>(double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); | |||
| // float | |||
| template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Expm1<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -364,6 +389,8 @@ template void Rsqrt<float>(const float *input, float *output, const size_t count | |||
| template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); | |||
| // half | |||
| template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Expm1<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -42,6 +42,8 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -58,6 +60,8 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -66,6 +70,8 @@ MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -78,6 +84,8 @@ MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -94,6 +102,8 @@ MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| UnaryOpGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| UnaryOpGpuKernel, double) | |||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| UnaryOpGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,14 +20,29 @@ import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_cos(): | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| def cos(nptype): | |||
| np.random.seed(0) | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(nptype) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| output_ms = P.Cos()(Tensor(x_np)) | |||
| output_np = np.cos(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_cos_float16(): | |||
| cos(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_cos_float32(): | |||
| cos(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_cos_float64(): | |||
| cos(np.float64) | |||
| @@ -14,15 +14,14 @@ | |||
| # ============================================================================ | |||
| """ test loss """ | |||
| import numpy as np | |||
| import mindspore | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore.nn.loss.loss import L1Loss | |||
| import mindspore.context as context | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| class WeightedLoss(_Loss): | |||
| def __init__(self, reduction='mean', weights=1.0): | |||
| super(WeightedLoss, self).__init__(reduction) | |||
| @@ -33,10 +32,13 @@ class WeightedLoss(_Loss): | |||
| x = self.abs(base - target) | |||
| return self.get_loss(x, self.weights) | |||
| def test_WeightedLoss(): | |||
| def weighted_loss(nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| loss = WeightedLoss() | |||
| input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)) | |||
| target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)) | |||
| input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype)) | |||
| target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype)) | |||
| output_data = loss(input_data, target_data) | |||
| error_range = np.ones(shape=output_data.shape) * 10e-6 | |||
| @@ -50,14 +52,26 @@ def test_WeightedLoss(): | |||
| diff = test_output - output_data * 3 | |||
| assert np.all(abs(diff.asnumpy()) < error_range) | |||
| loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]))) | |||
| y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32) | |||
| y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32) | |||
| loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]).astype(nptype))) | |||
| y_true = Tensor(np.array([[0., 1.], [0., 0.]]).astype(nptype)) | |||
| y_pred = Tensor(np.array([[1., 1.], [1., 0.]]).astype(nptype)) | |||
| test_data = 0.35 | |||
| output = loss(y_true, y_pred) | |||
| diff = test_data - output.asnumpy() | |||
| assert np.all(abs(diff) < error_range) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_weighted_loss_float32(): | |||
| weighted_loss(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_weighted_loss_float64(): | |||
| weighted_loss(np.float64) | |||
| class CustomLoss(_Loss): | |||
| def __init__(self, reduction='mean'): | |||
| super(CustomLoss, self).__init__(reduction) | |||
| @@ -67,10 +81,10 @@ class CustomLoss(_Loss): | |||
| x = self.abs(base - target) | |||
| return self.get_loss(x, weights=2.0) | |||
| def test_CustomLoss(): | |||
| def custom_loss(nptype): | |||
| loss = L1Loss() | |||
| input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)) | |||
| target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)) | |||
| input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype)) | |||
| target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype)) | |||
| output_data = loss(input_data, target_data) | |||
| error_range = np.ones(shape=output_data.shape) * 10e-6 | |||
| @@ -78,3 +92,21 @@ def test_CustomLoss(): | |||
| test_output = customloss(input_data, target_data) | |||
| diff = test_output - output_data * 2.0 | |||
| assert np.all(abs(diff.asnumpy()) < error_range) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_custom_loss_float16(): | |||
| custom_loss(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_custom_loss_float32(): | |||
| custom_loss(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_custom_loss_float64(): | |||
| custom_loss(np.float64) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -31,12 +31,9 @@ class NetNeg(nn.Cell): | |||
| return self.neg(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_neg(): | |||
| x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32) | |||
| x1_np = np.random.uniform(-2, 2, 1).astype(np.float32) | |||
| def neg(nptype): | |||
| x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype) | |||
| x1_np = np.random.uniform(-2, 2, 1).astype(nptype) | |||
| x0 = Tensor(x0_np) | |||
| x1 = Tensor(x1_np) | |||
| expect0 = np.negative(x0_np) | |||
| @@ -45,23 +42,41 @@ def test_neg(): | |||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| neg = NetNeg() | |||
| output0 = neg(x0) | |||
| neg_net = NetNeg() | |||
| output0 = neg_net(x0) | |||
| diff0 = output0.asnumpy() - expect0 | |||
| assert np.all(diff0 < error0) | |||
| assert output0.shape == expect0.shape | |||
| output1 = neg(x1) | |||
| output1 = neg_net(x1) | |||
| diff1 = output1.asnumpy() - expect1 | |||
| assert np.all(diff1 < error1) | |||
| assert output1.shape == expect1.shape | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| neg = NetNeg() | |||
| output0 = neg(x0) | |||
| neg_net = NetNeg() | |||
| output0 = neg_net(x0) | |||
| diff0 = output0.asnumpy() - expect0 | |||
| assert np.all(diff0 < error0) | |||
| assert output0.shape == expect0.shape | |||
| output1 = neg(x1) | |||
| output1 = neg_net(x1) | |||
| diff1 = output1.asnumpy() - expect1 | |||
| assert np.all(diff1 < error1) | |||
| assert output1.shape == expect1.shape | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_neg_float16(): | |||
| neg(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_neg_float32(): | |||
| neg(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_neg_float64(): | |||
| neg(np.float64) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,14 +20,29 @@ import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sin(): | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| def sin(nptype): | |||
| np.random.seed(0) | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(nptype) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| output_ms = P.Sin()(Tensor(x_np)) | |||
| output_np = np.sin(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sin_float16(): | |||
| sin(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sin_float32(): | |||
| sin(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sin_float64(): | |||
| sin(np.float64) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,18 +20,40 @@ import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sqrt(): | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| def sqrt(nptype): | |||
| np.random.seed(0) | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(nptype) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| output_ms = P.Sqrt()(Tensor(x_np)) | |||
| output_np = np.sqrt(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sqrt_float16(): | |||
| sqrt(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sqrt_float32(): | |||
| sqrt(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sqrt_float64(): | |||
| sqrt(np.float64) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_rsqrt(): | |||
| np.random.seed(0) | |||
| x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) | |||
| output_ms = P.Rsqrt()(Tensor(x_np)) | |||
| output_np = 1 / np.sqrt(x_np) | |||
| assert np.allclose(output_ms.asnumpy(), output_np) | |||