|
|
|
@@ -23,6 +23,7 @@ from mindspore.ops import functional as F |
|
|
|
from mindspore.ops.functional import identity |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
from mindspore._extends import cell_attr_register |
|
|
|
from mindspore.common.api import ms_function |
|
|
|
from ..cell import Cell |
|
|
|
from .activation import get_activation |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
@@ -261,7 +262,9 @@ class ClipByNorm(Cell): |
|
|
|
self.expand_dims = P.ExpandDims() |
|
|
|
self.dtype = P.DType() |
|
|
|
|
|
|
|
@ms_function |
|
|
|
def construct(self, x, clip_norm): |
|
|
|
"""add ms_function decorator for pynative mode""" |
|
|
|
mul_x = F.square(x) |
|
|
|
l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) |
|
|
|
cond = self.greater_(l2sum, self.zero) |
|
|
|
|