From 21a2c7bcbee7925eeb6c24cf1c80e208be3fb632 Mon Sep 17 00:00:00 2001 From: CaoJian Date: Sat, 20 Mar 2021 21:42:25 +0800 Subject: [PATCH] fix centernet loss error. --- model_zoo/research/cv/centernet/src/centernet_pose.py | 7 ++----- model_zoo/research/cv/centernet/src/utils.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/model_zoo/research/cv/centernet/src/centernet_pose.py b/model_zoo/research/cv/centernet/src/centernet_pose.py index c558156a03..9cac9a8fa5 100644 --- a/model_zoo/research/cv/centernet/src/centernet_pose.py +++ b/model_zoo/research/cv/centernet/src/centernet_pose.py @@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell): cond = self.less_equal(self.base, flag_reduce) else: cond = self.less_equal(self.base, flag_sum) - overflow = cond - if overflow: - succ = False - else: - succ = self.optimizer(grads) + + succ = self.optimizer(grads) ret = (loss, cond, scaling_sens) return ops.depend(ret, succ) diff --git a/model_zoo/research/cv/centernet/src/utils.py b/model_zoo/research/cv/centernet/src/utils.py index 42898f85ab..89b73d803c 100644 --- a/model_zoo/research/cv/centernet/src/utils.py +++ b/model_zoo/research/cv/centernet/src/utils.py @@ -137,7 +137,7 @@ class GatherFeature(nn.Cell): self.gather_nd = ops.GatherD() self.expand_dims = ops.ExpandDims() else: - self.gather_nd = ops.GatherND() + self.gather_nd = ops.GatherNd() def construct(self, feat, ind): """gather by specified index"""