|
|
|
@@ -89,6 +89,7 @@ class Cell(Cell_): |
|
|
|
self._scope = None |
|
|
|
self._phase = 'train' |
|
|
|
self._parameter_layout_dict = {} |
|
|
|
self._parallel_parameter_name_list = () |
|
|
|
self._create_time = int(time.time() * 1e9) |
|
|
|
self.phase_prefix = "" |
|
|
|
self.parameter_broadcast_done = False |
|
|
|
@@ -213,6 +214,16 @@ class Cell(Cell_): |
|
|
|
raise TypeError("'parameter_layout_dict' must be dict type.") |
|
|
|
self._parameter_layout_dict = value |
|
|
|
|
|
|
|
@property
def parallel_parameter_name_list(self):
    """Return the names of the parameters this cell keeps under parallel execution."""
    return self._parallel_parameter_name_list

@parallel_parameter_name_list.setter
def parallel_parameter_name_list(self, value):
    """
    Set the parallel parameter name list.

    Args:
        value (Union[list, tuple]): Names of the parameters to keep. A tuple
            is accepted as well as a list, consistent with the attribute's
            default value, which is an empty tuple.

    Raises:
        TypeError: If `value` is neither a list nor a tuple.
    """
    # The backing attribute is initialized to an empty tuple, so rejecting
    # tuples here was inconsistent; accept both sequence types.
    if not isinstance(value, (list, tuple)):
        raise TypeError("'parallel_parameter_name_list' must be list or tuple type.")
    self._parallel_parameter_name_list = value
|
|
|
|
|
|
|
def get_func_graph_proto(self):
    """Return the graph binary proto for this cell's compiled phase."""
    # Graph is keyed by "<phase>.<create_time>" in the executor cache.
    graph_key = f"{self.phase}.{self.create_time}"
    return _executor._get_func_graph_proto(self, graph_key, "anf_ir", True)
|
|
|
@@ -656,6 +667,28 @@ class Cell(Cell_): |
|
|
|
""" |
|
|
|
return None |
|
|
|
|
|
|
|
def remove_redundant_parameters(self):
    """
    Remove the redundant parameters.

    Walks every subcell (via `cells_and_names`) and drops each parameter whose
    name is not present in `self.parallel_parameter_name_list`, both from the
    subcell's `_params` dict and from any `ParameterTuple`-valued attribute in
    the subcell's `__dict__`. Removals are logged at info level.
    """
    cells = self.cells_and_names()
    for _, cell in cells:
        # list() snapshots the items so entries can be popped from _params
        # while iterating.
        params = cell._params.items()
        for param_name, param in list(params):
            if param.name not in self.parallel_parameter_name_list:
                cell._params.pop(param_name)
                logger.info("remove the redundant parameter: %s", param.name)
                continue
        # Also filter parameters held inside ParameterTuple attributes.
        cell_dict = cell.__dict__
        for key in cell_dict:
            if isinstance(cell_dict[key], ParameterTuple):
                param_tuple = cell_dict[key]
                new_param_tuple = []
                for param in param_tuple:
                    if param.name not in self.parallel_parameter_name_list:
                        logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
                        continue
                    new_param_tuple.append(param)
                # Replacing the value of an existing key is safe during dict
                # iteration: no keys are added or removed.
                # NOTE(review): the tuple is rebuilt even when nothing was
                # filtered out — presumably harmless; confirm ParameterTuple
                # construction has no side effects (e.g. parameter cloning).
                cell.__dict__[key] = ParameterTuple(new_param_tuple)
|
|
|
|
|
|
|
def init_parameters_data(self, auto_parallel_mode=False): |
|
|
|
""" |
|
|
|
Initialize all parameters and replace the original saved parameters in cell. |
|
|
|
@@ -750,7 +783,7 @@ class Cell(Cell_): |
|
|
|
""" |
|
|
|
Returns all trainable parameters. |
|
|
|
|
|
|
|
Returns a list of all trainable parmeters. |
|
|
|
Returns a list of all trainable parameters. |
|
|
|
|
|
|
|
Args: |
|
|
|
recurse (bool): Whether contains the trainable parameters of subcells. Default: True. |
|
|
|
@@ -1031,7 +1064,7 @@ class Cell(Cell_): |
|
|
|
Note: |
|
|
|
fn must be defined as the following code. `cell_name` is the name of registered cell. |
|
|
|
`grad_input` is gradient passed to the cell. `grad_output` is the gradient computed and passed to the |
|
|
|
next cell or primitve, which may be modified and returned. |
|
|
|
next cell or primitive, which may be modified and returned. |
|
|
|
hook_fn(cell_name, grad_input, grad_output) -> Tensor or None. |
|
|
|
|
|
|
|
Args: |
|
|
|
|