Browse Source

!13260 fix SequentialCell and CellList parameter name bug

From: @caozhou_huawei
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
5c39c33c92
2 changed files with 87 additions and 3 deletions
  1. +85
    -2
      mindspore/nn/layer/container.py
  2. +2
    -1
      mindspore/nn/optim/optimizer.py

+ 85
- 2
mindspore/nn/layer/container.py View File

@@ -35,6 +35,43 @@ def _valid_cell(cell):
raise TypeError('Cell {} is not subclass of Cell'.format(cell)) raise TypeError('Cell {} is not subclass of Cell'.format(cell))




def _get_prefix_and_index(cells):
"""get prefix and index of parameter name in sequential cell or cell list"""
prefix = ""
index = 0
if not cells:
return prefix, index

cell_list = list(cells.items())
first_param, first_key = None, None
second_param, second_key = None, None
for key, cell in cell_list:
try:
_, param = next(cell.parameters_and_names())
except StopIteration:
continue
if first_param is None:
first_param = param
first_key = key
continue
second_param = param
second_key = key
break

if first_param is None:
return prefix, index

split_names = first_param.name.split(".")
for idx, name in enumerate(split_names):
if name == first_key:
prefix = ".".join(split_names[:idx])
prefix = prefix + "." if prefix else prefix
index = idx
if second_param is not None and second_param.name.split(".")[idx] == second_key:
break
return prefix, index


class _CellListBase(): class _CellListBase():
""" """
An interface for base the cell as list. An interface for base the cell as list.
@@ -97,19 +134,26 @@ class SequentialCell(Cell):
""" """
def __init__(self, *args): def __init__(self, *args):
super(SequentialCell, self).__init__() super(SequentialCell, self).__init__()
self._is_dynamic_name = []
if len(args) == 1: if len(args) == 1:
cells = args[0] cells = args[0]
if isinstance(cells, list): if isinstance(cells, list):
for index, cell in enumerate(cells): for index, cell in enumerate(cells):
self.insert_child_to_cell(str(index), cell) self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
elif isinstance(cells, OrderedDict): elif isinstance(cells, OrderedDict):
for name, cell in cells.items(): for name, cell in cells.items():
self.insert_child_to_cell(name, cell) self.insert_child_to_cell(name, cell)
cell.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)
else: else:
raise TypeError('Cells must be list or orderedDict') raise TypeError('Cells must be list or orderedDict')
else: else:
for index, cell in enumerate(args): for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell) self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
self.cell_list = list(self._cells.values()) self.cell_list = list(self._cells.values())


def __getitem__(self, index): def __getitem__(self, index):
@@ -121,9 +165,11 @@ class SequentialCell(Cell):


def __setitem__(self, index, cell): def __setitem__(self, index, cell):
if _valid_cell(cell): if _valid_cell(cell):
prefix, _ = _get_prefix_and_index(self._cells)
index = _valid_index(len(self), index) index = _valid_index(len(self), index)
key = list(self._cells.keys())[index] key = list(self._cells.keys())[index]
self._cells[key] = cell self._cells[key] = cell
cell.update_parameters_name(prefix + key + ".")
self.cell_list = list(self._cells.values()) self.cell_list = list(self._cells.values())


def __delitem__(self, index): def __delitem__(self, index):
@@ -131,12 +177,25 @@ class SequentialCell(Cell):
index = _valid_index(len(self), index) index = _valid_index(len(self), index)
key = list(self._cells.keys())[index] key = list(self._cells.keys())[index]
del self._cells[key] del self._cells[key]
del self._is_dynamic_name[index]
elif isinstance(index, slice): elif isinstance(index, slice):
keys = list(self._cells.keys())[index] keys = list(self._cells.keys())[index]
for key in keys: for key in keys:
del self._cells[key] del self._cells[key]
del self._is_dynamic_name[index]
else: else:
raise TypeError('Index {} is not int type or slice type'.format(index)) raise TypeError('Index {} is not int type or slice type'.format(index))
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for idx, key in enumerate(self._cells.keys()):
cell = self._cells[key]
if self._is_dynamic_name[idx]:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell
else:
temp_dict[key] = cell
self._cells = temp_dict
self.cell_list = list(self._cells.values()) self.cell_list = list(self._cells.values())


def __len__(self): def __len__(self):
@@ -165,6 +224,9 @@ class SequentialCell(Cell):
[26.999863 26.999863]]]] [26.999863 26.999863]]]]
""" """
if _valid_cell(cell): if _valid_cell(cell):
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._is_dynamic_name.append(True)
self._cells[str(len(self))] = cell self._cells[str(len(self))] = cell
self.cell_list = list(self._cells.values()) self.cell_list = list(self._cells.values())


@@ -202,9 +264,10 @@ class CellList(_CellListBase, Cell):
(2): ReLU<> (2): ReLU<>
> >
""" """
def __init__(self, *args):
def __init__(self, *args, **kwargs):
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True
_CellListBase.__init__(self) _CellListBase.__init__(self)
Cell.__init__(self)
Cell.__init__(self, auto_prefix)
if len(args) == 1: if len(args) == 1:
self.extend(args[0]) self.extend(args[0])


@@ -220,6 +283,9 @@ class CellList(_CellListBase, Cell):
if not isinstance(index, int) and _valid_cell(cell): if not isinstance(index, int) and _valid_cell(cell):
raise TypeError('Index {} is not int type'.format(index)) raise TypeError('Index {} is not int type'.format(index))
index = _valid_index(len(self), index) index = _valid_index(len(self), index)
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(index) + ".")
self._cells[str(index)] = cell self._cells[str(index)] = cell


def __delitem__(self, index): def __delitem__(self, index):
@@ -233,8 +299,12 @@ class CellList(_CellListBase, Cell):
else: else:
raise TypeError('Index {} is not int type or slice type'.format(index)) raise TypeError('Index {} is not int type or slice type'.format(index))
# adjust orderedDict # adjust orderedDict
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict() temp_dict = OrderedDict()
for idx, cell in enumerate(self._cells.values()): for idx, cell in enumerate(self._cells.values()):
if self._auto_prefix:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell temp_dict[str(idx)] = cell
self._cells = temp_dict self._cells = temp_dict


@@ -253,10 +323,17 @@ class CellList(_CellListBase, Cell):
idx = _valid_index(len(self), index) idx = _valid_index(len(self), index)
_valid_cell(cell) _valid_cell(cell)
length = len(self) length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx: while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
self._cells[str(length)] = self._cells[str(length - 1)] self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1 length -= 1
self._cells[str(idx)] = cell self._cells[str(idx)] = cell
if self._auto_prefix:
cell.update_parameters_name(prefix + str(idx) + ".")


def extend(self, cells): def extend(self, cells):
""" """
@@ -267,14 +344,20 @@ class CellList(_CellListBase, Cell):
""" """
if not isinstance(cells, list): if not isinstance(cells, list):
raise TypeError('Cells {} should be list of subcells'.format(cells)) raise TypeError('Cells {} should be list of subcells'.format(cells))
prefix, _ = _get_prefix_and_index(self._cells)
for cell in cells: for cell in cells:
if _valid_cell(cell): if _valid_cell(cell):
if self._auto_prefix:
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = cell self._cells[str(len(self))] = cell
return self return self


def append(self, cell): def append(self, cell):
"""Appends a given cell to the end of the list.""" """Appends a given cell to the end of the list."""
if _valid_cell(cell): if _valid_cell(cell):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = cell self._cells[str(len(self))] = cell


def set_grad(self, flag=True): def set_grad(self, flag=True):


+ 2
- 1
mindspore/nn/optim/optimizer.py View File

@@ -146,7 +146,8 @@ class Optimizer(Cell):
self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')


if self.is_group_lr: if self.is_group_lr:
self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
else ParameterTuple(self.group_lr)
else: else:
self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate') self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')




Loading…
Cancel
Save