From f51d5fc411104153c8169a1b6313b89412faa64e Mon Sep 17 00:00:00 2001 From: wsc Date: Mon, 20 Apr 2020 16:34:00 +0800 Subject: [PATCH] Add interface to get the attributes of network --- mindspore/nn/cell.py | 9 ++++++++- mindspore/nn/wrap/cell_wrapper.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5507d12af8..853abff0b6 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -56,7 +56,7 @@ class Cell: >>> def construct(self, x): >>> return self.relu(x) """ - def __init__(self, auto_prefix=True): + def __init__(self, auto_prefix=True, flags=None): self._params = OrderedDict() self._cells = OrderedDict() self.training = False @@ -74,6 +74,8 @@ class Cell: if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]: self._get_construct_inputs_number_and_name() self._parallel_inputs_run = None + if flags: + self.add_flags(**flags) @property def create_time(self): @@ -603,6 +605,11 @@ class Cell: cell.add_flags_recursive(**flags) return self + def get_flags(self): + if not hasattr(self, "_mindspore_flags"): + self._mindspore_flags = {} + return self._mindspore_flags + def to_float(self, dst_type): """ Add cast on all inputs of cell and child cells to run with certain float type. diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 64c382557a..453ddae0fc 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -226,7 +226,7 @@ class DataWrapper(Cell): """ def __init__(self, network, dataset_types, dataset_shapes, queue_name): - super(DataWrapper, self).__init__(auto_prefix=False) + super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) self.network = network