|
|
|
@@ -328,6 +328,9 @@ class Cell(Cell_): |
|
|
|
return tuple(res) |
|
|
|
|
|
|
|
def cast_inputs(self, inputs, dst_type): |
|
|
|
""" |
|
|
|
Cast inputs to specified type. |
|
|
|
""" |
|
|
|
res = list() |
|
|
|
for item in inputs: |
|
|
|
if isinstance(item, tuple): |
|
|
|
@@ -971,6 +974,9 @@ class Cell(Cell_): |
|
|
|
yield param |
|
|
|
|
|
|
|
def check_names(self): |
|
|
|
""" |
|
|
|
Check the names of cell parameters. |
|
|
|
""" |
|
|
|
names = set("") |
|
|
|
for value, param in self.parameters_and_names(): |
|
|
|
if param.name in names: |
|
|
|
@@ -1116,6 +1122,11 @@ class Cell(Cell_): |
|
|
|
return cells |
|
|
|
|
|
|
|
def add_flags(self, **flags): |
|
|
|
""" |
|
|
|
Add customized attributes for cell. |
|
|
|
|
|
|
|
This method is also called when the cell class is instantiated and the class parameter 'flag' is set to True. |
|
|
|
""" |
|
|
|
if not hasattr(self, "_mindspore_flags"): |
|
|
|
self._mindspore_flags = {} |
|
|
|
self._mindspore_flags.update({**flags}) |
|
|
|
@@ -1123,6 +1134,9 @@ class Cell(Cell_): |
|
|
|
return self |
|
|
|
|
|
|
|
def add_flags_recursive(self, **flags): |
|
|
|
""" |
|
|
|
If a cell contains child cells, this method can recursively customize attributes of all cells. |
|
|
|
""" |
|
|
|
self.add_flags(**flags) |
|
|
|
for cell in self.cells(): |
|
|
|
cell.add_flags_recursive(**flags) |
|
|
|
@@ -1133,6 +1147,9 @@ class Cell(Cell_): |
|
|
|
self._cell_init_args += str({**args}) |
|
|
|
|
|
|
|
def get_flags(self): |
|
|
|
""" |
|
|
|
Get the attributes of cell's flags. |
|
|
|
""" |
|
|
|
if not hasattr(self, "_mindspore_flags"): |
|
|
|
self._mindspore_flags = {} |
|
|
|
return self._mindspore_flags |
|
|
|
|