Browse Source

!9686 Fix a bug of naming the variable in LGamma

From: @peixu_ren
Reviewed-by: @zichun_ye,@sunnybeike
Signed-off-by: @sunnybeike
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b724bac9fc
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      mindspore/nn/layer/math.py

+ 6
- 6
mindspore/nn/layer/math.py View File

@@ -253,18 +253,18 @@ class LGamma(Cell):
z = self.select(need_to_reflect, neg_input, x - 1) z = self.select(need_to_reflect, neg_input, x - 1)


@constexpr @constexpr
def _calculate_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
x = k_base_lanczos_coeff
def _calculate_reflected_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
reflex_x = k_base_lanczos_coeff
for i in range(8): for i in range(8):
product_ = k_lanczos_coefficients[i] / (z + i + 1) product_ = k_lanczos_coefficients[i] / (z + i + 1)
x = product_ + x
return x
x = _calculate_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
reflex_x = product_ + reflex_x
return reflex_x
reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)


t = z + self.lanczos_gamma_plus_one_half t = z + self.lanczos_gamma_plus_one_half
log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half


log_y = self.log(x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi
log_y = self.log(reflex_x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi


abs_input = self.abs(x) abs_input = self.abs(x)
abs_frac_input = abs_input - self.floor(abs_input) abs_frac_input = abs_input - self.floor(abs_input)


Loading…
Cancel
Save