|
|
|
@@ -22,6 +22,7 @@ from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.ops.functional import identity |
|
|
|
from mindspore.ops.operations import _inner_ops as inner |
|
|
|
from mindspore.ops.primitive import constexpr |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore._extends import cell_attr_register |
|
|
|
from mindspore.common.api import ms_function |
|
|
|
@@ -236,6 +237,13 @@ class Dense(Cell): |
|
|
|
return str_info |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _is_equal_one(x): |
|
|
|
if x is None: |
|
|
|
return False |
|
|
|
return bool(x.asnumpy().mean() == 1.0) |
|
|
|
|
|
|
|
|
|
|
|
class ClipByNorm(Cell): |
|
|
|
r""" |
|
|
|
Clips tensor values to a maximum :math:`L_2`-norm. |
|
|
|
@@ -290,7 +298,10 @@ class ClipByNorm(Cell): |
|
|
|
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) |
|
|
|
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) |
|
|
|
|
|
|
|
intermediate = x * clip_norm |
|
|
|
if _is_equal_one(clip_norm): |
|
|
|
intermediate = x |
|
|
|
else: |
|
|
|
intermediate = x * clip_norm |
|
|
|
max_norm = self.max_op(l2norm, clip_norm) |
|
|
|
values_clip = self.cast(intermediate, mstype.float32) / self.expand_dims(max_norm, -1) |
|
|
|
values_clip = self.reshape(values_clip, self.shape(x)) |
|
|
|
|