| @@ -158,16 +158,16 @@ class SmoothL1Loss(_Loss): | |||||
| .. math:: | .. math:: | ||||
| L_{i} = | L_{i} = | ||||
| \begin{cases} | \begin{cases} | ||||
| 0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < \text{sigma}; \\ | |||||
| |x_i - y_i| - 0.5, & \text{otherwise. } | |||||
| \frac{0.5 (x_i - y_i)^{2}}{\text{beta}}, & \text{if } |x_i - y_i| < \text{beta} \\ | |||||
| |x_i - y_i| - 0.5 \text{beta}, & \text{otherwise. } | |||||
| \end{cases} | \end{cases} | ||||
| Here :math:`\text{sigma}` controls the point where the loss function changes from quadratic to linear. | |||||
| Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear. | |||||
| Its default value is 1.0. :math:`N` is the batch size. This function returns an | Its default value is 1.0. :math:`N` is the batch size. This function returns an | ||||
| unreduced loss Tensor. | unreduced loss Tensor. | ||||
| Args: | Args: | ||||
| sigma (float): A parameter used to control the point where the function will change from | |||||
| beta (float): A parameter used to control the point where the function will change from | |||||
| quadratic to linear. Default: 1.0. | quadratic to linear. Default: 1.0. | ||||
| Inputs: | Inputs: | ||||
| @@ -183,10 +183,10 @@ class SmoothL1Loss(_Loss): | |||||
| >>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32) | >>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32) | ||||
| >>> loss(input_data, target_data) | >>> loss(input_data, target_data) | ||||
| """ | """ | ||||
| def __init__(self, sigma=1.0): | |||||
| def __init__(self, beta=1.0): | |||||
| super(SmoothL1Loss, self).__init__() | super(SmoothL1Loss, self).__init__() | ||||
| self.sigma = sigma | |||||
| self.smooth_l1_loss = P.SmoothL1Loss(self.sigma) | |||||
| self.beta = beta | |||||
| self.smooth_l1_loss = P.SmoothL1Loss(self.beta) | |||||
| def construct(self, base, target): | def construct(self, base, target): | ||||
| return self.smooth_l1_loss(base, target) | return self.smooth_l1_loss(base, target) | ||||