Browse Source

!17449 clean codecheck for master thor optimizer

From: @sl_wang
Reviewed-by: @kisnwang,@c_34
Signed-off-by: @c_34
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
294438c4ed
3 changed files with 21 additions and 13 deletions
  1. +6
    -6
      mindspore/nn/layer/thor_layer.py
  2. +4
    -0
      mindspore/nn/optim/thor.py
  3. +11
    -7
      mindspore/train/train_thor/convert_utils.py

+ 6
- 6
mindspore/nn/layer/thor_layer.py View File

@@ -221,9 +221,9 @@ class _ConvThor(Cell):
self.dilation = dilation
self.group = Validator.check_positive_int(group)
self.has_bias = has_bias
self._validate_kernel_size(kernel_size)
self._validate_stride(stride)
self._validate_dilation(dilation)
self.__validate_kernel_size(kernel_size)
self.__validate_stride(stride)
self.__validate_dilation(dilation)
if in_channels % group != 0:
raise ValueError("Attr 'in_channels' of 'Conv2DThor' Op must be divisible by "
"attr 'group' of 'Conv2DThor' Op.")
@@ -243,7 +243,7 @@ class _ConvThor(Cell):
logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
self.bias = None

def _validate_kernel_size(self, kernel_size):
def __validate_kernel_size(self, kernel_size):
"""validate kernel size."""
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
@@ -251,14 +251,14 @@ class _ConvThor(Cell):
raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
+ str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")

def _validate_stride(self, stride):
def __validate_stride(self, stride):
"""validate stride."""
if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
+ str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")

def _validate_dilation(self, dilation):
def __validate_dilation(self, dilation):
"""validate dilation."""
if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:


+ 4
- 0
mindspore/nn/optim/thor.py View File

@@ -42,6 +42,7 @@ op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
@@ -63,6 +64,7 @@ GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
hyper_map_op = C.HyperMap()


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
@@ -97,6 +99,8 @@ def clip_gradient(enable_clip_grad, gradients):
return gradients

C0 = 16


def caculate_device_shape(matrix_dim, channel, is_a):
ll = (0)
if is_a:


+ 11
- 7
mindspore/train/train_thor/convert_utils.py View File

@@ -25,12 +25,13 @@ class ConvertNetUntils():
Convert net to thor layer net
"""
def __init__(self):
self._convert_method_map = {nn.Dense: self._convert_dense,
nn.Embedding: self._convert_embedding,
nn.Conv2d: self._convert_conv2d}
self._convert_method_map = {nn.Dense: ConvertNetUntils._convert_dense,
nn.Embedding: ConvertNetUntils._convert_embedding,
nn.Conv2d: ConvertNetUntils._convert_conv2d}


def _convert_dense(self, subcell):
@staticmethod
def _convert_dense(subcell):
"""
convert dense cell to second_order cell
"""
@@ -65,7 +66,8 @@ class ConvertNetUntils():
return new_subcell


def _convert_embedding(self, subcell):
@staticmethod
def _convert_embedding(subcell):
"""
convert embedding cell to second_order cell
"""
@@ -76,7 +78,8 @@ class ConvertNetUntils():
return new_subcell


def _convert_conv2d(self, subcell):
@staticmethod
def _convert_conv2d(subcell):
"""
convert conv2d cell to second_order cell
"""
@@ -140,7 +143,8 @@ class ConvertModelUtils():
convert model to thor model utils
"""

def convert_to_thor_model(self, model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
@staticmethod
def convert_to_thor_model(model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
loss_scale_manager=None, keep_batchnorm_fp32=False):

"""


Loading…
Cancel
Save