|
|
|
@@ -48,6 +48,9 @@ grad_overflow = P.FloatStatus() |
|
|
|
def _tensor_grad_overflow(grad): |
|
|
|
return grad_overflow(grad) |
|
|
|
|
|
|
|
@_grad_overflow.register("RowTensor") |
|
|
|
def _tensor_grad_overflow_row_tensor(grad): |
|
|
|
return grad_overflow(grad.values) |
|
|
|
|
|
|
|
class DynamicLossScaleUpdateCell(Cell): |
|
|
|
r""" |
|
|
|
|