Merge pull request !4706 from Peilin/smoothL1Loss-fixtags/v0.7.0-beta
| @@ -18,47 +18,47 @@ | |||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target, | |||||
| __global__ void SmoothL1LossKernel(const int input_size, const float beta, const T *prediction, const T *target, | |||||
| T *loss) { | T *loss) { | ||||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | ||||
| T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]); | |||||
| if (value < sigma) { | |||||
| loss[i] = static_cast<T>(0.5) * value * value; | |||||
| T value = fabsf(prediction[i] - target[i]); | |||||
| if (value < beta) { | |||||
| loss[i] = 0.5 * value * value / beta; | |||||
| } else { | } else { | ||||
| loss[i] = value - static_cast<T>(0.5); | |||||
| loss[i] = value - (0.5 * beta); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, | |||||
| void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss, | |||||
| cudaStream_t stream) { | cudaStream_t stream) { | ||||
| SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss); | |||||
| SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, loss); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target, | |||||
| __global__ void SmoothL1LossGradKernel(const int input_size, const float beta, const T *prediction, const T *target, | |||||
| const T *dloss, T *dx) { | const T *dloss, T *dx) { | ||||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | ||||
| T value = prediction[i] - target[i]; | T value = prediction[i] - target[i]; | ||||
| if (value > static_cast<T>(sigma)) { | |||||
| if (value > beta) { | |||||
| dx[i] = dloss[i]; | dx[i] = dloss[i]; | ||||
| } else if (value < static_cast<T>(-sigma)) { | |||||
| } else if (value < -beta) { | |||||
| dx[i] = -dloss[i]; | dx[i] = -dloss[i]; | ||||
| } else { | } else { | ||||
| dx[i] = value * dloss[i]; | |||||
| dx[i] = (value / beta) * dloss[i]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, | |||||
| void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss, | |||||
| T *dx, cudaStream_t stream) { | T *dx, cudaStream_t stream) { | ||||
| SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, | |||||
| SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, | |||||
| dloss, dx); | dloss, dx); | ||||
| } | } | ||||
| template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target, | |||||
| float *loss, cudaStream_t stream); | |||||
| template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target, | |||||
| const float *dloss, float *dx, cudaStream_t stream); | |||||
| template void SmoothL1Loss<float>(const int &input_size, const float &beta, const float *prediction, | |||||
| const float *target, float *loss, cudaStream_t stream); | |||||
| template void SmoothL1LossGrad<float>(const int &input_size, const float &beta, const float *prediction, | |||||
| const float *target, const float *dloss, float *dx, cudaStream_t stream); | |||||
| @@ -17,9 +17,9 @@ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | ||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | ||||
| template <typename T> | template <typename T> | ||||
| void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, | |||||
| void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss, | |||||
| cudaStream_t stream); | cudaStream_t stream); | ||||
| template <typename T> | template <typename T> | ||||
| void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, | |||||
| void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss, | |||||
| T *dx, cudaStream_t stream); | T *dx, cudaStream_t stream); | ||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | ||||
| @@ -26,7 +26,7 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class SmoothL1LossGpuKernel : public GpuKernel { | class SmoothL1LossGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {} | |||||
| SmoothL1LossGpuKernel() : input_size_(1), beta_(1.0) {} | |||||
| ~SmoothL1LossGpuKernel() override = default; | ~SmoothL1LossGpuKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -39,7 +39,7 @@ class SmoothL1LossGpuKernel : public GpuKernel { | |||||
| T *target = GetDeviceAddress<T>(inputs, 1); | T *target = GetDeviceAddress<T>(inputs, 1); | ||||
| T *loss = GetDeviceAddress<T>(outputs, 0); | T *loss = GetDeviceAddress<T>(outputs, 0); | ||||
| SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| SmoothL1Loss(input_size_, beta_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ class SmoothL1LossGpuKernel : public GpuKernel { | |||||
| input_size_ *= input_shape[i]; | input_size_ *= input_shape[i]; | ||||
| } | } | ||||
| sigma_ = GetAttr<float>(kernel_node, "sigma"); | |||||
| beta_ = GetAttr<float>(kernel_node, "beta"); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -63,7 +63,7 @@ class SmoothL1LossGpuKernel : public GpuKernel { | |||||
| private: | private: | ||||
| size_t input_size_; | size_t input_size_; | ||||
| float sigma_; | |||||
| float beta_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| @@ -26,7 +26,7 @@ namespace kernel { | |||||
| template <typename T> | template <typename T> | ||||
| class SmoothL1LossGradGpuKernel : public GpuKernel { | class SmoothL1LossGradGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {} | |||||
| SmoothL1LossGradGpuKernel() : input_size_(1), beta_(1.0) {} | |||||
| ~SmoothL1LossGradGpuKernel() override = default; | ~SmoothL1LossGradGpuKernel() override = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -40,7 +40,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel { | |||||
| T *dloss = GetDeviceAddress<T>(inputs, 2); | T *dloss = GetDeviceAddress<T>(inputs, 2); | ||||
| T *dx = GetDeviceAddress<T>(outputs, 0); | T *dx = GetDeviceAddress<T>(outputs, 0); | ||||
| SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| SmoothL1LossGrad(input_size_, beta_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -50,7 +50,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel { | |||||
| input_size_ *= input_shape[i]; | input_size_ *= input_shape[i]; | ||||
| } | } | ||||
| sigma_ = GetAttr<float>(kernel_node, "sigma"); | |||||
| beta_ = GetAttr<float>(kernel_node, "beta"); | |||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -64,7 +64,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel { | |||||
| private: | private: | ||||
| size_t input_size_; | size_t input_size_; | ||||
| float sigma_; | |||||
| float beta_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| @@ -713,7 +713,7 @@ def get_bprop_top_kv2(self): | |||||
| @bprop_getters.register(P.SmoothL1Loss) | @bprop_getters.register(P.SmoothL1Loss) | ||||
| def get_bprop_smooth_l1_loss(self): | def get_bprop_smooth_l1_loss(self): | ||||
| """Grad definition for `SmoothL1Loss` operation.""" | """Grad definition for `SmoothL1Loss` operation.""" | ||||
| grad = G.SmoothL1LossGrad(self.sigma) | |||||
| grad = G.SmoothL1LossGrad(self.beta) | |||||
| def bprop(prediction, target, out, dout): | def bprop(prediction, target, out, dout): | ||||
| dx = grad(prediction, target, dout) | dx = grad(prediction, target, dout) | ||||
| @@ -1274,7 +1274,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer): | |||||
| """Computes gradient for prediction on SmoothL1Loss.""" | """Computes gradient for prediction on SmoothL1Loss.""" | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, sigma=1.0): | |||||
| def __init__(self, beta=1.0): | |||||
| pass | pass | ||||
| def infer_shape(self, prediction, target, dloss): | def infer_shape(self, prediction, target, dloss): | ||||
| @@ -1725,11 +1725,11 @@ class SmoothL1Loss(PrimitiveWithInfer): | |||||
| Sets input prediction as `X`, input target as `Y`, output as `loss`. Then, | Sets input prediction as `X`, input target as `Y`, output as `loss`. Then, | ||||
| .. math:: | .. math:: | ||||
| \text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr | |||||
| \left |x \right|-0.5, &\text{otherwise}\end{cases} | |||||
| \text{SmoothL1Loss} = \begin{cases} \frac{0.5 x^{2}}{\text{beta}, &if \left |x \right | < \text{beta} \cr | |||||
| \left |x \right|-0.5 \text{beta}, &\text{otherwise}\end{cases} | |||||
| Args: | Args: | ||||
| sigma (float): A parameter used to control the point where the function will change from | |||||
| beta (float): A parameter used to control the point where the function will change from | |||||
| quadratic to linear. Default: 1.0. | quadratic to linear. Default: 1.0. | ||||
| Inputs: | Inputs: | ||||
| @@ -1748,9 +1748,9 @@ class SmoothL1Loss(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, sigma=1.0): | |||||
| validator.check_value_type('sigma', sigma, [float], self.name) | |||||
| validator.check('sigma', sigma, '', 0, Rel.GT, self.name) | |||||
| def __init__(self, beta=1.0): | |||||
| validator.check_value_type('beta', beta, [float], self.name) | |||||
| validator.check('beta', beta, '', 0, Rel.GT, self.name) | |||||
| self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) | self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) | ||||
| def infer_shape(self, prediction, target): | def infer_shape(self, prediction, target): | ||||
| @@ -21,25 +21,39 @@ import mindspore.nn as nn | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| def smoothl1loss(beta): | |||||
| np.random.seed(42) | |||||
| prediction = np.random.randn(20).astype(np.float32) | |||||
| target = np.random.randn(20).astype(np.float32) | |||||
| net = nn.SmoothL1Loss(beta) | |||||
| return net(Tensor(prediction), Tensor(target)) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_smoothl1loss(): | def test_smoothl1loss(): | ||||
| np.random.seed(42) | |||||
| prediction = np.random.randn(20).astype(np.float32) | |||||
| target = np.random.randn(20).astype(np.float32) | |||||
| sigma = 1.0 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| epsilon = 1e-6 | |||||
| net = nn.SmoothL1Loss(sigma) | |||||
| loss = net(Tensor(prediction), Tensor(target)) | |||||
| beta = 1.0 | |||||
| loss = smoothl1loss(beta) | |||||
| expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304, | expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304, | ||||
| 2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008, | 2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008, | ||||
| 0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174, | 0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174, | ||||
| 0.08826803, 1.109165] | 0.08826803, 1.109165] | ||||
| assert np.allclose(loss.asnumpy(), expect) | |||||
| diff = np.absolute(loss.asnumpy() - np.array(expect)) | |||||
| assert(diff < epsilon).all() | |||||
| beta = 1 / 9 | |||||
| loss = smoothl1loss(beta) | |||||
| expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504, | |||||
| 2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524, | |||||
| 0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619, | |||||
| 0.3646064, 1.5536094] | |||||
| diff = np.absolute(loss.asnumpy() - np.array(expect)) | |||||
| assert(diff < epsilon).all() | |||||
| class Grad(nn.Cell): | class Grad(nn.Cell): | ||||
| @@ -53,20 +67,26 @@ class Grad(nn.Cell): | |||||
| return gout | return gout | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_smoothl1loss_grad(): | |||||
| def smoothl1loss_grad(beta): | |||||
| np.random.seed(42) | np.random.seed(42) | ||||
| prediction = np.random.randn(20).astype(np.float32) | prediction = np.random.randn(20).astype(np.float32) | ||||
| target = np.random.randn(20).astype(np.float32) | target = np.random.randn(20).astype(np.float32) | ||||
| sens = np.random.randn(20).astype(np.float32) | sens = np.random.randn(20).astype(np.float32) | ||||
| sigma = 1.0 | |||||
| net = nn.SmoothL1Loss(sigma) | |||||
| net = nn.SmoothL1Loss(beta) | |||||
| grad = Grad(net) | grad = Grad(net) | ||||
| dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) | |||||
| return grad(Tensor(prediction), Tensor(target), Tensor(sens)) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_smoothl1loss_grad(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||||
| epsilon = 1e-6 | |||||
| beta = 1.0 | |||||
| dx = smoothl1loss_grad(beta) | |||||
| dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093, | dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093, | ||||
| 0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229, | 0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229, | ||||
| 0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995, | 0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995, | ||||
| @@ -77,5 +97,23 @@ def test_smoothl1loss_grad(): | |||||
| -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995, | -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995, | ||||
| -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451] | -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451] | ||||
| assert np.allclose(dx[0].asnumpy(), dx1_expect) | |||||
| assert np.allclose(dx[1].asnumpy(), dx2_expect) | |||||
| diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect)) | |||||
| diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect)) | |||||
| assert(diff1 < epsilon).all() | |||||
| assert(diff2 < epsilon).all() | |||||
| beta = 1 / 9 | |||||
| dx = smoothl1loss_grad(beta) | |||||
| dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522, | |||||
| 0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402, | |||||
| 0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995, | |||||
| 0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451] | |||||
| dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522, | |||||
| -0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402, | |||||
| -0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995, | |||||
| -0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451] | |||||
| diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect)) | |||||
| diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect)) | |||||
| assert(diff1 < epsilon).all() | |||||
| assert(diff2 < epsilon).all() | |||||