Browse Source

fix centernet loss error.

tags/v1.2.0-rc1
CaoJian 4 years ago
parent
commit
21a2c7bcbe
2 changed files with 3 additions and 6 deletions
  1. +2
    -5
      model_zoo/research/cv/centernet/src/centernet_pose.py
  2. +1
    -1
      model_zoo/research/cv/centernet/src/utils.py

+ 2
- 5
model_zoo/research/cv/centernet/src/centernet_pose.py View File

@@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell):
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if overflow:
succ = False
else:
succ = self.optimizer(grads)

succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return ops.depend(ret, succ)



+ 1
- 1
model_zoo/research/cv/centernet/src/utils.py View File

@@ -137,7 +137,7 @@ class GatherFeature(nn.Cell):
self.gather_nd = ops.GatherD()
self.expand_dims = ops.ExpandDims()
else:
self.gather_nd = ops.GatherND()
self.gather_nd = ops.GatherNd()

def construct(self, feat, ind):
"""gather by specified index"""


Loading…
Cancel
Save