| @@ -308,11 +308,8 @@ class CenterNetWithLossScaleCell(nn.Cell): | |||||
| cond = self.less_equal(self.base, flag_reduce) | cond = self.less_equal(self.base, flag_reduce) | ||||
| else: | else: | ||||
| cond = self.less_equal(self.base, flag_sum) | cond = self.less_equal(self.base, flag_sum) | ||||
| overflow = cond | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond, scaling_sens) | ret = (loss, cond, scaling_sens) | ||||
| return ops.depend(ret, succ) | return ops.depend(ret, succ) | ||||
| @@ -137,7 +137,7 @@ class GatherFeature(nn.Cell): | |||||
| self.gather_nd = ops.GatherD() | self.gather_nd = ops.GatherD() | ||||
| self.expand_dims = ops.ExpandDims() | self.expand_dims = ops.ExpandDims() | ||||
| else: | else: | ||||
| self.gather_nd = ops.GatherND() | |||||
| self.gather_nd = ops.GatherNd() | |||||
| def construct(self, feat, ind): | def construct(self, feat, ind): | ||||
| """gather by specified index""" | """gather by specified index""" | ||||