From 0d5220d33c368679b94166b1f889d0cd7148cabc Mon Sep 17 00:00:00 2001
From: Peilin Wang
Date: Tue, 18 Aug 2020 17:53:11 -0400
Subject: [PATCH] modified documentation and gpu kernel for smoothL1Loss

fix pylint

changed doc and code for SmoothL1Loss to be the same as dchip. fixed grad kernel

fix ci
---
 .../gpu/cuda_impl/smooth_l1_loss_impl.cu      | 34 ++++-----
 .../gpu/cuda_impl/smooth_l1_loss_impl.cuh     |  4 +-
 .../gpu/nn/smooth_l1_loss_gpu_kernel.h        |  8 +--
 .../gpu/nn/smooth_l1_loss_grad_gpu_kernel.h   |  8 +--
 mindspore/ops/_grad/grad_nn_ops.py            |  2 +-
 mindspore/ops/operations/_grad_ops.py         |  2 +-
 mindspore/ops/operations/nn_ops.py            | 12 ++--
 tests/st/ops/gpu/test_smoothl1loss_op.py      | 72 ++++++++++++++-----
 8 files changed, 90 insertions(+), 52 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu
index 9050044b7f..bc3e38b8a4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu
@@ -18,47 +18,47 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                    T *loss) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-    T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]);
-    if (value < sigma) {
-      loss[i] = static_cast<T>(0.5) * value * value;
+    T value = fabsf(prediction[i] - target[i]);
+    if (value < beta) {
+      loss[i] = 0.5 * value * value / beta;
     } else {
-      loss[i] = value - static_cast<T>(0.5);
+      loss[i] = value - (0.5 * beta);
     }
   }
 }
 
 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream) {
-  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
+  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, loss);
 }
 
 template <typename T>
-__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossGradKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                        const T *dloss, T *dx) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
     T value = prediction[i] - target[i];
-    if (value > static_cast<T>(sigma)) {
+    if (value > beta) {
       dx[i] = dloss[i];
-    } else if (value < static_cast<T>(-sigma)) {
+    } else if (value < -beta) {
       dx[i] = -dloss[i];
     } else {
-      dx[i] = value * dloss[i];
+      dx[i] = (value / beta) * dloss[i];
     }
   }
 }
 
 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream) {
-  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
+  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target,
                                                                              dloss, dx);
 }
 
-template void SmoothL1Loss<float>(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                                  float *loss, cudaStream_t stream);
-template void SmoothL1LossGrad<float>(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                                      const float *dloss, float *dx, cudaStream_t stream);
+template void SmoothL1Loss<float>(const int &input_size, const float &beta, const float *prediction,
+                                  const float *target, float *loss, cudaStream_t stream);
+template void SmoothL1LossGrad<float>(const int &input_size, const float &beta, const float *prediction,
+                                      const float *target, const float *dloss, float *dx, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh
index 7938e18a3b..ef6409763a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh
@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream);
 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h
index 1ebd56874b..c0c0fb5fc8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h
@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -39,7 +39,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
     T *target = GetDeviceAddress<T>(inputs, 1);
     T *loss = GetDeviceAddress<T>(outputs, 0);
 
-    SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1Loss(input_size_, beta_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
@@ -49,7 +49,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
 
-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }
@@ -63,7 +63,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
index 88e8bbd30e..a9a716d8f1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGradGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGradGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGradGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,7 +40,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
     T *dloss = GetDeviceAddress<T>(inputs, 2);
     T *dx = GetDeviceAddress<T>(outputs, 0);
 
-    SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1LossGrad(input_size_, beta_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
@@ -50,7 +50,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
 
-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }
@@ -64,7 +64,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 8f4cf8496d..3e0443a9e2 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -689,7 +689,7 @@ def get_bprop_top_kv2(self):
 @bprop_getters.register(P.SmoothL1Loss)
 def get_bprop_smooth_l1_loss(self):
     """Grad definition for `SmoothL1Loss` operation."""
-    grad = G.SmoothL1LossGrad(self.sigma)
+    grad = G.SmoothL1LossGrad(self.beta)
 
     def bprop(prediction, target, out, dout):
         dx = grad(prediction, target, dout)
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 417441fb41..10c9256e50 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1258,7 +1258,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
     """Computes gradient for prediction on SmoothL1Loss."""
 
     @prim_attr_register
-    def __init__(self, sigma=1.0):
+    def __init__(self, beta=1.0):
         pass
 
     def infer_shape(self, prediction, target, dloss):
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 50316fab8f..89ec0483c1 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1658,11 +1658,11 @@ class SmoothL1Loss(PrimitiveWithInfer):
     Sets input prediction as `X`, input target as `Y`, output as `loss`. Then,
 
     .. math::
-        \text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr
-        \left |x \right|-0.5, &\text{otherwise}\end{cases}
+        \text{SmoothL1Loss} = \begin{cases} \frac{0.5 x^{2}}{\text{beta}}, &if \left |x \right | < \text{beta} \cr
+        \left |x \right|-0.5 \text{beta}, &\text{otherwise}\end{cases}
 
     Args:
-        sigma (float): A parameter used to control the point where the function will change from
+        beta (float): A parameter used to control the point where the function will change from
             quadratic to linear. Default: 1.0.
 
     Inputs:
@@ -1681,9 +1681,9 @@ class SmoothL1Loss(PrimitiveWithInfer):
     """
 
     @prim_attr_register
-    def __init__(self, sigma=1.0):
-        validator.check_value_type('sigma', sigma, [float], self.name)
-        validator.check('sigma', sigma, '', 0, Rel.GT, self.name)
+    def __init__(self, beta=1.0):
+        validator.check_value_type('beta', beta, [float], self.name)
+        validator.check('beta', beta, '', 0, Rel.GT, self.name)
         self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
 
     def infer_shape(self, prediction, target):
diff --git a/tests/st/ops/gpu/test_smoothl1loss_op.py b/tests/st/ops/gpu/test_smoothl1loss_op.py
index 040f404eb0..10d8411d20 100644
--- a/tests/st/ops/gpu/test_smoothl1loss_op.py
+++ b/tests/st/ops/gpu/test_smoothl1loss_op.py
@@ -21,25 +21,39 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import composite as C
-context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+def smoothl1loss(beta):
+    np.random.seed(42)
+    prediction = np.random.randn(20).astype(np.float32)
+    target = np.random.randn(20).astype(np.float32)
+
+    net = nn.SmoothL1Loss(beta)
+    return net(Tensor(prediction), Tensor(target))
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_smoothl1loss():
-    np.random.seed(42)
-    prediction = np.random.randn(20).astype(np.float32)
-    target = np.random.randn(20).astype(np.float32)
-    sigma = 1.0
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    epsilon = 1e-6
 
-    net = nn.SmoothL1Loss(sigma)
-    loss = net(Tensor(prediction), Tensor(target))
+    beta = 1.0
+    loss = smoothl1loss(beta)
     expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304,
               2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008,
               0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174,
               0.08826803, 1.109165]
-    assert np.allclose(loss.asnumpy(), expect)
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert(diff < epsilon).all()
+    beta = 1 / 9
+    loss = smoothl1loss(beta)
+    expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504,
+              2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524,
+              0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619,
+              0.3646064, 1.5536094]
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert(diff < epsilon).all()
 
 
 class Grad(nn.Cell):
@@ -53,20 +67,26 @@ class Grad(nn.Cell):
         return gout
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_smoothl1loss_grad():
+def smoothl1loss_grad(beta):
     np.random.seed(42)
     prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)
     sens = np.random.randn(20).astype(np.float32)
-    sigma = 1.0
 
-    net = nn.SmoothL1Loss(sigma)
+    net = nn.SmoothL1Loss(beta)
     grad = Grad(net)
-    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
+    return grad(Tensor(prediction), Tensor(target), Tensor(sens))
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_smoothl1loss_grad():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+    epsilon = 1e-6
+    beta = 1.0
+    dx = smoothl1loss_grad(beta)
 
     dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
                   0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
                   0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
@@ -77,5 +97,23 @@ def test_smoothl1loss_grad():
                   -0.24838912, 0.46063876, -0.41411355, -0.04507046, 1.4708229,
                   -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995,
                   -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451]
-    assert np.allclose(dx[0].asnumpy(), dx1_expect)
-    assert np.allclose(dx[1].asnumpy(), dx2_expect)
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert(diff1 < epsilon).all()
+    assert(diff2 < epsilon).all()
+
+    beta = 1 / 9
+    dx = smoothl1loss_grad(beta)
+    dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522,
+                  0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402,
+                  0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995,
+                  0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451]
+    dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522,
+                  -0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402,
+                  -0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995,
+                  -0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451]
+
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert(diff1 < epsilon).all()
+    assert(diff2 < epsilon).all()
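For reference, the forward and backward formulas that the updated kernels and the tests above rely on reduce to a short NumPy sketch. This is illustrative only and not part of the patch; the helper names below are made up for the example.

# Minimal NumPy sketch of the beta-parameterized SmoothL1Loss used above (beta > 0).
# Illustrative reference only; these helpers are not MindSpore APIs.
import numpy as np

def smooth_l1_loss_ref(prediction, target, beta=1.0):
    # loss_i = 0.5 * d^2 / beta  if |d| < beta
    #        = |d| - 0.5 * beta  otherwise, with d = prediction - target
    d = np.abs(prediction - target)
    return np.where(d < beta, 0.5 * d * d / beta, d - 0.5 * beta)

def smooth_l1_loss_grad_ref(prediction, target, dloss, beta=1.0):
    # dL/dprediction_i = dloss_i * clip(d / beta, -1, 1); dL/dtarget is its negation,
    # which is why dx2_expect is the elementwise negative of dx1_expect in the tests.
    d = prediction - target
    dpred = dloss * np.clip(d / beta, -1.0, 1.0)
    return dpred, -dpred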