def _get_prefix_and_index(cells):
    """
    Get the prefix and key position of parameter names in a sequential cell or cell list.

    The name of the first parameter of the first cell that owns parameters is split on "."
    and searched for that cell's key. Every match updates the result; the search stops at
    the first match that is confirmed by the second parameter-owning cell, i.e. whose name
    holds its own key at the same position. Without a confirmation the last match wins.

    Args:
        cells (dict): Mapping from child key to cell. Each cell must provide
            `parameters_and_names()` returning an iterator of `(name, parameter)` pairs,
            and each parameter must have a dot-separated `name`.

    Returns:
        tuple, `(prefix, index)`. `prefix` is the part of the name before the key, with a
        trailing "." unless it is empty. `index` is the position of the key among the
        dot-separated components. `("", 0)` is returned when `cells` is empty, when no
        cell owns a parameter, or when the key does not occur in the name.
    """
    prefix = ""
    index = 0
    if not cells:
        return prefix, index

    # Only the leading parameter of the first two parameter-owning cells is inspected.
    first_param, first_key = None, None
    second_param, second_key = None, None
    for key, cell in list(cells.items()):
        try:
            _, param = next(cell.parameters_and_names())
        except StopIteration:
            # The cell owns no parameter, so it carries no naming information.
            continue
        if first_param is None:
            first_param, first_key = param, key
            continue
        second_param, second_key = param, key
        break

    if first_param is None:
        return prefix, index

    split_names = first_param.name.split(".")
    for idx, name in enumerate(split_names):
        if name != first_key:
            continue
        parent = ".".join(split_names[:idx])
        prefix = parent + "." if parent else parent
        index = idx
        # NOTE(review): this confirmation is assumed to be nested under the key match;
        # the indentation was reconstructed from a whitespace-collapsed diff.
        if second_param is None:
            continue
        second_names = second_param.name.split(".")
        # A second name that is too short to hold a key at this position cannot confirm
        # the match; indexing it unchecked used to raise IndexError.
        if idx < len(second_names) and second_names[idx] == second_key:
            break
    return prefix, index
def __init__(self, *args):
    """
    Register the given cells as children and prefix their parameter names with their key.

    Args:
        args (list, OrderedDict): Either one list of cells, one OrderedDict mapping names
            to cells, or the cells themselves as positional arguments. Cells given by
            position are keyed by their index and flagged as dynamically named; cells of
            an OrderedDict keep their own key and are flagged as not dynamically named.

    Raises:
        TypeError: If the single argument is neither a list nor an OrderedDict.
    """
    super(SequentialCell, self).__init__()
    self._is_dynamic_name = []

    if len(args) == 1 and not isinstance(args[0], list):
        named_cells = args[0]
        if not isinstance(named_cells, OrderedDict):
            raise TypeError('Cells must be list or orderedDict')
        entries = named_cells.items()
        dynamic = False
    else:
        # A single list and plain positional cells are both keyed by position.
        sequence = args[0] if len(args) == 1 else args
        entries = ((str(position), child) for position, child in enumerate(sequence))
        dynamic = True

    for child_key, child in entries:
        self.insert_child_to_cell(child_key, child)
        child.update_parameters_name(child_key + ".")
        self._is_dynamic_name.append(dynamic)

    self.cell_list = list(self._cells.values())
+ ".".join(param.name.split(".")[key_index+1:]) + temp_dict[str(idx)] = cell + else: + temp_dict[key] = cell + self._cells = temp_dict self.cell_list = list(self._cells.values()) def __len__(self): @@ -165,6 +224,9 @@ class SequentialCell(Cell): [26.999863 26.999863]]]] """ if _valid_cell(cell): + prefix, _ = _get_prefix_and_index(self._cells) + cell.update_parameters_name(prefix + str(len(self)) + ".") + self._is_dynamic_name.append(True) self._cells[str(len(self))] = cell self.cell_list = list(self._cells.values()) @@ -202,9 +264,10 @@ class CellList(_CellListBase, Cell): (2): ReLU<> > """ - def __init__(self, *args): + def __init__(self, *args, **kwargs): + auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True _CellListBase.__init__(self) - Cell.__init__(self) + Cell.__init__(self, auto_prefix) if len(args) == 1: self.extend(args[0]) @@ -220,6 +283,9 @@ class CellList(_CellListBase, Cell): if not isinstance(index, int) and _valid_cell(cell): raise TypeError('Index {} is not int type'.format(index)) index = _valid_index(len(self), index) + if self._auto_prefix: + prefix, _ = _get_prefix_and_index(self._cells) + cell.update_parameters_name(prefix + str(index) + ".") self._cells[str(index)] = cell def __delitem__(self, index): @@ -233,8 +299,12 @@ class CellList(_CellListBase, Cell): else: raise TypeError('Index {} is not int type or slice type'.format(index)) # adjust orderedDict + prefix, key_index = _get_prefix_and_index(self._cells) temp_dict = OrderedDict() for idx, cell in enumerate(self._cells.values()): + if self._auto_prefix: + for _, param in cell.parameters_and_names(): + param.name = prefix + str(idx) + "." 
def extend(self, cells):
    """
    Appends the cells of a Python list to the end of this list.

    Args:
        cells (list): The cells to be appended.

    Returns:
        The cell list itself.

    Raises:
        TypeError: If `cells` is not a list.
    """
    if not isinstance(cells, list):
        raise TypeError('Cells {} should be list of subcells'.format(cells))
    # Resolve the name prefix only when it is going to be used, as append() does, so that
    # a list built with auto_prefix=False never inspects the parameter names of its cells.
    prefix = _get_prefix_and_index(self._cells)[0] if self._auto_prefix else ""
    for cell in cells:
        if _valid_cell(cell):
            if self._auto_prefix:
                cell.update_parameters_name(prefix + str(len(self)) + ".")
            self._cells[str(len(self))] = cell
    return self
self._build_single_lr(learning_rate, 'learning_rate')