From 35b4a804e2d12458e0fc475e4ca8c4430db78d55 Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Mon, 7 Dec 2020 15:14:37 -0500 Subject: [PATCH] Modify the names of parameter check --- mindspore/nn/layer/math.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 324d81b8bb..884ab24372 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -245,7 +245,7 @@ class LGamma(Cell): def construct(self, x): input_dtype = self.dtype(x) - _check_input_dtype("input", input_dtype, [mstype.float16, mstype.float32], self.cls_name) + _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name) infinity = self.fill(input_dtype, self.shape(x), self.inf) need_to_reflect = self.less(x, 0.5) @@ -352,7 +352,7 @@ class DiGamma(Cell): def construct(self, x): input_dtype = self.dtype(x) - _check_input_dtype("input_x", input_dtype, [mstype.float16, mstype.float32], self.cls_name) + _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name) need_to_reflect = self.less(x, 0.5) neg_input = -x z = self.select(need_to_reflect, neg_input, x - 1) @@ -612,8 +612,8 @@ class IGamma(Cell): def construct(self, a, x): a_dtype = self.dtype(a) x_dtype = self.dtype(x) - _check_input_dtype("input_a", a_dtype, [mstype.float32], self.cls_name) - _check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name) + _check_input_dtype("a", a_dtype, [mstype.float32], self.cls_name) + _check_input_dtype("x", x_dtype, a_dtype, self.cls_name) domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) ax = a * self.log(x) - x - self.lgamma(a) @@ -688,8 +688,8 @@ class LBeta(Cell): def construct(self, x, y): x_dtype = self.dtype(x) y_dtype = self.dtype(y) - _check_input_dtype("input_x", x_dtype, [mstype.float16, mstype.float32], self.cls_name) - _check_input_dtype("input_y", y_dtype, x_dtype, self.cls_name) + _check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name) + _check_input_dtype("y", y_dtype, x_dtype, self.cls_name) x_plus_y = x + y para_shape = self.shape(x_plus_y) broadcastto = P.BroadcastTo(para_shape)