|
|
@@ -35,6 +35,43 @@ def _valid_cell(cell): |
|
|
raise TypeError('Cell {} is not subclass of Cell'.format(cell)) |
|
|
raise TypeError('Cell {} is not subclass of Cell'.format(cell)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_prefix_and_index(cells): |
|
|
|
|
|
"""get prefix and index of parameter name in sequential cell or cell list""" |
|
|
|
|
|
prefix = "" |
|
|
|
|
|
index = 0 |
|
|
|
|
|
if not cells: |
|
|
|
|
|
return prefix, index |
|
|
|
|
|
|
|
|
|
|
|
cell_list = list(cells.items()) |
|
|
|
|
|
first_param, first_key = None, None |
|
|
|
|
|
second_param, second_key = None, None |
|
|
|
|
|
for key, cell in cell_list: |
|
|
|
|
|
try: |
|
|
|
|
|
_, param = next(cell.parameters_and_names()) |
|
|
|
|
|
except StopIteration: |
|
|
|
|
|
continue |
|
|
|
|
|
if first_param is None: |
|
|
|
|
|
first_param = param |
|
|
|
|
|
first_key = key |
|
|
|
|
|
continue |
|
|
|
|
|
second_param = param |
|
|
|
|
|
second_key = key |
|
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
if first_param is None: |
|
|
|
|
|
return prefix, index |
|
|
|
|
|
|
|
|
|
|
|
split_names = first_param.name.split(".") |
|
|
|
|
|
for idx, name in enumerate(split_names): |
|
|
|
|
|
if name == first_key: |
|
|
|
|
|
prefix = ".".join(split_names[:idx]) |
|
|
|
|
|
prefix = prefix + "." if prefix else prefix |
|
|
|
|
|
index = idx |
|
|
|
|
|
if second_param is not None and second_param.name.split(".")[idx] == second_key: |
|
|
|
|
|
break |
|
|
|
|
|
return prefix, index |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _CellListBase(): |
|
|
class _CellListBase(): |
|
|
""" |
|
|
""" |
|
|
An interface for base the cell as list. |
|
|
An interface for base the cell as list. |
|
|
@@ -97,19 +134,26 @@ class SequentialCell(Cell): |
|
|
""" |
|
|
""" |
|
|
def __init__(self, *args): |
|
|
def __init__(self, *args): |
|
|
super(SequentialCell, self).__init__() |
|
|
super(SequentialCell, self).__init__() |
|
|
|
|
|
self._is_dynamic_name = [] |
|
|
if len(args) == 1: |
|
|
if len(args) == 1: |
|
|
cells = args[0] |
|
|
cells = args[0] |
|
|
if isinstance(cells, list): |
|
|
if isinstance(cells, list): |
|
|
for index, cell in enumerate(cells): |
|
|
for index, cell in enumerate(cells): |
|
|
self.insert_child_to_cell(str(index), cell) |
|
|
self.insert_child_to_cell(str(index), cell) |
|
|
|
|
|
cell.update_parameters_name(str(index) + ".") |
|
|
|
|
|
self._is_dynamic_name.append(True) |
|
|
elif isinstance(cells, OrderedDict): |
|
|
elif isinstance(cells, OrderedDict): |
|
|
for name, cell in cells.items(): |
|
|
for name, cell in cells.items(): |
|
|
self.insert_child_to_cell(name, cell) |
|
|
self.insert_child_to_cell(name, cell) |
|
|
|
|
|
cell.update_parameters_name(name + ".") |
|
|
|
|
|
self._is_dynamic_name.append(False) |
|
|
else: |
|
|
else: |
|
|
raise TypeError('Cells must be list or orderedDict') |
|
|
raise TypeError('Cells must be list or orderedDict') |
|
|
else: |
|
|
else: |
|
|
for index, cell in enumerate(args): |
|
|
for index, cell in enumerate(args): |
|
|
self.insert_child_to_cell(str(index), cell) |
|
|
self.insert_child_to_cell(str(index), cell) |
|
|
|
|
|
cell.update_parameters_name(str(index) + ".") |
|
|
|
|
|
self._is_dynamic_name.append(True) |
|
|
self.cell_list = list(self._cells.values()) |
|
|
self.cell_list = list(self._cells.values()) |
|
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
def __getitem__(self, index): |
|
|
@@ -121,9 +165,11 @@ class SequentialCell(Cell): |
|
|
|
|
|
|
|
|
def __setitem__(self, index, cell): |
|
|
def __setitem__(self, index, cell): |
|
|
if _valid_cell(cell): |
|
|
if _valid_cell(cell): |
|
|
|
|
|
prefix, _ = _get_prefix_and_index(self._cells) |
|
|
index = _valid_index(len(self), index) |
|
|
index = _valid_index(len(self), index) |
|
|
key = list(self._cells.keys())[index] |
|
|
key = list(self._cells.keys())[index] |
|
|
self._cells[key] = cell |
|
|
self._cells[key] = cell |
|
|
|
|
|
cell.update_parameters_name(prefix + key + ".") |
|
|
self.cell_list = list(self._cells.values()) |
|
|
self.cell_list = list(self._cells.values()) |
|
|
|
|
|
|
|
|
def __delitem__(self, index): |
|
|
def __delitem__(self, index): |
|
|
@@ -131,12 +177,25 @@ class SequentialCell(Cell): |
|
|
index = _valid_index(len(self), index) |
|
|
index = _valid_index(len(self), index) |
|
|
key = list(self._cells.keys())[index] |
|
|
key = list(self._cells.keys())[index] |
|
|
del self._cells[key] |
|
|
del self._cells[key] |
|
|
|
|
|
del self._is_dynamic_name[index] |
|
|
elif isinstance(index, slice): |
|
|
elif isinstance(index, slice): |
|
|
keys = list(self._cells.keys())[index] |
|
|
keys = list(self._cells.keys())[index] |
|
|
for key in keys: |
|
|
for key in keys: |
|
|
del self._cells[key] |
|
|
del self._cells[key] |
|
|
|
|
|
del self._is_dynamic_name[index] |
|
|
else: |
|
|
else: |
|
|
raise TypeError('Index {} is not int type or slice type'.format(index)) |
|
|
raise TypeError('Index {} is not int type or slice type'.format(index)) |
|
|
|
|
|
prefix, key_index = _get_prefix_and_index(self._cells) |
|
|
|
|
|
temp_dict = OrderedDict() |
|
|
|
|
|
for idx, key in enumerate(self._cells.keys()): |
|
|
|
|
|
cell = self._cells[key] |
|
|
|
|
|
if self._is_dynamic_name[idx]: |
|
|
|
|
|
for _, param in cell.parameters_and_names(): |
|
|
|
|
|
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:]) |
|
|
|
|
|
temp_dict[str(idx)] = cell |
|
|
|
|
|
else: |
|
|
|
|
|
temp_dict[key] = cell |
|
|
|
|
|
self._cells = temp_dict |
|
|
self.cell_list = list(self._cells.values()) |
|
|
self.cell_list = list(self._cells.values()) |
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
def __len__(self): |
|
|
@@ -165,6 +224,9 @@ class SequentialCell(Cell): |
|
|
[26.999863 26.999863]]]] |
|
|
[26.999863 26.999863]]]] |
|
|
""" |
|
|
""" |
|
|
if _valid_cell(cell): |
|
|
if _valid_cell(cell): |
|
|
|
|
|
prefix, _ = _get_prefix_and_index(self._cells) |
|
|
|
|
|
cell.update_parameters_name(prefix + str(len(self)) + ".") |
|
|
|
|
|
self._is_dynamic_name.append(True) |
|
|
self._cells[str(len(self))] = cell |
|
|
self._cells[str(len(self))] = cell |
|
|
self.cell_list = list(self._cells.values()) |
|
|
self.cell_list = list(self._cells.values()) |
|
|
|
|
|
|
|
|
@@ -202,9 +264,10 @@ class CellList(_CellListBase, Cell): |
|
|
(2): ReLU<> |
|
|
(2): ReLU<> |
|
|
> |
|
|
> |
|
|
""" |
|
|
""" |
|
|
def __init__(self, *args): |
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
|
|
auto_prefix = kwargs["auto_prefix"] if "auto_prefix" in kwargs.keys() else True |
|
|
_CellListBase.__init__(self) |
|
|
_CellListBase.__init__(self) |
|
|
Cell.__init__(self) |
|
|
|
|
|
|
|
|
Cell.__init__(self, auto_prefix) |
|
|
if len(args) == 1: |
|
|
if len(args) == 1: |
|
|
self.extend(args[0]) |
|
|
self.extend(args[0]) |
|
|
|
|
|
|
|
|
@@ -220,6 +283,9 @@ class CellList(_CellListBase, Cell): |
|
|
if not isinstance(index, int) and _valid_cell(cell): |
|
|
if not isinstance(index, int) and _valid_cell(cell): |
|
|
raise TypeError('Index {} is not int type'.format(index)) |
|
|
raise TypeError('Index {} is not int type'.format(index)) |
|
|
index = _valid_index(len(self), index) |
|
|
index = _valid_index(len(self), index) |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
prefix, _ = _get_prefix_and_index(self._cells) |
|
|
|
|
|
cell.update_parameters_name(prefix + str(index) + ".") |
|
|
self._cells[str(index)] = cell |
|
|
self._cells[str(index)] = cell |
|
|
|
|
|
|
|
|
def __delitem__(self, index): |
|
|
def __delitem__(self, index): |
|
|
@@ -233,8 +299,12 @@ class CellList(_CellListBase, Cell): |
|
|
else: |
|
|
else: |
|
|
raise TypeError('Index {} is not int type or slice type'.format(index)) |
|
|
raise TypeError('Index {} is not int type or slice type'.format(index)) |
|
|
# adjust orderedDict |
|
|
# adjust orderedDict |
|
|
|
|
|
prefix, key_index = _get_prefix_and_index(self._cells) |
|
|
temp_dict = OrderedDict() |
|
|
temp_dict = OrderedDict() |
|
|
for idx, cell in enumerate(self._cells.values()): |
|
|
for idx, cell in enumerate(self._cells.values()): |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
for _, param in cell.parameters_and_names(): |
|
|
|
|
|
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:]) |
|
|
temp_dict[str(idx)] = cell |
|
|
temp_dict[str(idx)] = cell |
|
|
self._cells = temp_dict |
|
|
self._cells = temp_dict |
|
|
|
|
|
|
|
|
@@ -253,10 +323,17 @@ class CellList(_CellListBase, Cell): |
|
|
idx = _valid_index(len(self), index) |
|
|
idx = _valid_index(len(self), index) |
|
|
_valid_cell(cell) |
|
|
_valid_cell(cell) |
|
|
length = len(self) |
|
|
length = len(self) |
|
|
|
|
|
prefix, key_index = _get_prefix_and_index(self._cells) |
|
|
while length > idx: |
|
|
while length > idx: |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
tmp_cell = self._cells[str(length-1)] |
|
|
|
|
|
for _, param in tmp_cell.parameters_and_names(): |
|
|
|
|
|
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:]) |
|
|
self._cells[str(length)] = self._cells[str(length - 1)] |
|
|
self._cells[str(length)] = self._cells[str(length - 1)] |
|
|
length -= 1 |
|
|
length -= 1 |
|
|
self._cells[str(idx)] = cell |
|
|
self._cells[str(idx)] = cell |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
cell.update_parameters_name(prefix + str(idx) + ".") |
|
|
|
|
|
|
|
|
def extend(self, cells): |
|
|
def extend(self, cells): |
|
|
""" |
|
|
""" |
|
|
@@ -267,14 +344,20 @@ class CellList(_CellListBase, Cell): |
|
|
""" |
|
|
""" |
|
|
if not isinstance(cells, list): |
|
|
if not isinstance(cells, list): |
|
|
raise TypeError('Cells {} should be list of subcells'.format(cells)) |
|
|
raise TypeError('Cells {} should be list of subcells'.format(cells)) |
|
|
|
|
|
prefix, _ = _get_prefix_and_index(self._cells) |
|
|
for cell in cells: |
|
|
for cell in cells: |
|
|
if _valid_cell(cell): |
|
|
if _valid_cell(cell): |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
cell.update_parameters_name(prefix + str(len(self)) + ".") |
|
|
self._cells[str(len(self))] = cell |
|
|
self._cells[str(len(self))] = cell |
|
|
return self |
|
|
return self |
|
|
|
|
|
|
|
|
def append(self, cell): |
|
|
def append(self, cell): |
|
|
"""Appends a given cell to the end of the list.""" |
|
|
"""Appends a given cell to the end of the list.""" |
|
|
if _valid_cell(cell): |
|
|
if _valid_cell(cell): |
|
|
|
|
|
if self._auto_prefix: |
|
|
|
|
|
prefix, _ = _get_prefix_and_index(self._cells) |
|
|
|
|
|
cell.update_parameters_name(prefix + str(len(self)) + ".") |
|
|
self._cells[str(len(self))] = cell |
|
|
self._cells[str(len(self))] = cell |
|
|
|
|
|
|
|
|
def set_grad(self, flag=True): |
|
|
def set_grad(self, flag=True): |
|
|
|