Browse Source

!1087 add param_perfix to cell

Merge pull request !1087 from SanjayChan/03cell
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
00b7877ec4
1 changed files with 19 additions and 2 deletions
  1. +19
    -2
      mindspore/nn/cell.py

+ 19
- 2
mindspore/nn/cell.py View File

@@ -60,6 +60,7 @@ class Cell:
self._cells = OrderedDict() self._cells = OrderedDict()
self.training = False self.training = False
self.pynative = False self.pynative = False
self._param_perfix = ''
self._auto_prefix = auto_prefix self._auto_prefix = auto_prefix
self._scope = None self._scope = None
self._phase = 'train' self._phase = 'train'
@@ -83,6 +84,24 @@ class Cell:
def cell_init_args(self): def cell_init_args(self):
return self._cell_init_args return self._cell_init_args


@property
def param_perfix(self):
"""
Param perfix is the prfix of curent cell's direct child parameter.
"""
return self._param_perfix

def update_cell_prefix(self):
"""
Update the all child cells' self.param_prefix.

After invoked, can get all the cell's children's name perfix by '_param_perfix'.
"""
cells = self.cells_and_names

for cell_name, cell in cells:
cell._param_perfix = cell_name

@cell_init_args.setter @cell_init_args.setter
def cell_init_args(self, value): def cell_init_args(self, value):
if not isinstance(value, str): if not isinstance(value, str):
@@ -223,7 +242,6 @@ class Cell:
Args: Args:
params (dict): The parameters dictionary used for init data graph. params (dict): The parameters dictionary used for init data graph.
""" """

if params is None: if params is None:
for key in self.parameters_dict(): for key in self.parameters_dict():
tensor = self.parameters_dict()[key].data tensor = self.parameters_dict()[key].data
@@ -253,7 +271,6 @@ class Cell:
Args: Args:
inputs (Function or Cell): inputs of construct method. inputs (Function or Cell): inputs of construct method.
""" """

parallel_inputs_run = [] parallel_inputs_run = []
if len(inputs) > self._construct_inputs_num: if len(inputs) > self._construct_inputs_num:
raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'. raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.


Loading…
Cancel
Save