Browse Source

add param_perfix to cell

tags/v0.3.0-alpha
chenzomi 5 years ago
parent
commit
6710e1feac
1 changed files with 19 additions and 2 deletions
  1. +19
    -2
      mindspore/nn/cell.py

+ 19
- 2
mindspore/nn/cell.py View File

@@ -61,6 +61,7 @@ class Cell:
self._cells = OrderedDict()
self.training = False
self.pynative = False
self._param_perfix = ''
self._auto_prefix = auto_prefix
self._scope = None
self._phase = 'train'
@@ -85,6 +86,24 @@ class Cell:
def cell_init_args(self):
return self._cell_init_args

@property
def param_perfix(self):
"""
Param perfix is the prfix of curent cell's direct child parameter.
"""
return self._param_perfix

def update_cell_prefix(self):
"""
Update the all child cells' self.param_prefix.

After invoked, can get all the cell's children's name perfix by '_param_perfix'.
"""
cells = self.cells_and_names

for cell_name, cell in cells:
cell._param_perfix = cell_name

@cell_init_args.setter
def cell_init_args(self, value):
if not isinstance(value, str):
@@ -225,7 +244,6 @@ class Cell:
Args:
params (dict): The parameters dictionary used for init data graph.
"""

if params is None:
for key in self.parameters_dict():
tensor = self.parameters_dict()[key].data
@@ -255,7 +273,6 @@ class Cell:
Args:
inputs (Function or Cell): inputs of construct method.
"""

parallel_inputs_run = []
if len(inputs) > self._construct_inputs_num:
raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.


Loading…
Cancel
Save