|
|
|
@@ -56,7 +56,7 @@ class Cell: |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> return self.relu(x) |
|
|
|
""" |
|
|
|
def __init__(self, auto_prefix=True): |
|
|
|
def __init__(self, auto_prefix=True, flags=None): |
|
|
|
self._params = OrderedDict() |
|
|
|
self._cells = OrderedDict() |
|
|
|
self.training = False |
|
|
|
@@ -74,6 +74,8 @@ class Cell: |
|
|
|
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]: |
|
|
|
self._get_construct_inputs_number_and_name() |
|
|
|
self._parallel_inputs_run = None |
|
|
|
if flags: |
|
|
|
self.add_flags(**flags) |
|
|
|
|
|
|
|
@property |
|
|
|
def create_time(self): |
|
|
|
@@ -603,6 +605,11 @@ class Cell: |
|
|
|
cell.add_flags_recursive(**flags) |
|
|
|
return self |
|
|
|
|
|
|
|
def get_flags(self): |
|
|
|
if not hasattr(self, "_mindspore_flags"): |
|
|
|
self._mindspore_flags = {} |
|
|
|
return self._mindspore_flags |
|
|
|
|
|
|
|
def to_float(self, dst_type): |
|
|
|
""" |
|
|
|
Add cast on all inputs of cell and child cells to run with certain float type. |
|
|
|
|