diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index a7fe712f3c..3f3f2ecdc0 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -60,6 +60,7 @@ class Cell: self._cells = OrderedDict() self.training = False self.pynative = False + self._param_perfix = '' self._auto_prefix = auto_prefix self._scope = None self._phase = 'train' @@ -83,6 +84,24 @@ class Cell: def cell_init_args(self): return self._cell_init_args + @property + def param_perfix(self): + """ + Param perfix is the prfix of curent cell's direct child parameter. + """ + return self._param_perfix + + def update_cell_prefix(self): + """ + Update the all child cells' self.param_prefix. + + After invoked, can get all the cell's children's name perfix by '_param_perfix'. + """ + cells = self.cells_and_names + + for cell_name, cell in cells: + cell._param_perfix = cell_name + @cell_init_args.setter def cell_init_args(self, value): if not isinstance(value, str): @@ -223,7 +242,6 @@ class Cell: Args: params (dict): The parameters dictionary used for init data graph. """ - if params is None: for key in self.parameters_dict(): tensor = self.parameters_dict()[key].data @@ -253,7 +271,6 @@ class Cell: Args: inputs (Function or Cell): inputs of construct method. """ - parallel_inputs_run = [] if len(inputs) > self._construct_inputs_num: raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.