Browse Source

!461 Add interface to get attributes of network.

Merge pull request !461 from wsc/master
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
507b63ea20
2 changed files with 9 additions and 2 deletions
  1. +8
    -1
      mindspore/nn/cell.py
  2. +1
    -1
      mindspore/nn/wrap/cell_wrapper.py

+ 8
- 1
mindspore/nn/cell.py View File

@@ -56,7 +56,7 @@ class Cell:
>>> def construct(self, x):
>>> return self.relu(x)
"""
def __init__(self, auto_prefix=True):
def __init__(self, auto_prefix=True, flags=None):
self._params = OrderedDict()
self._cells = OrderedDict()
self.training = False
@@ -74,6 +74,8 @@ class Cell:
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
self._get_construct_inputs_number_and_name()
self._parallel_inputs_run = None
if flags:
self.add_flags(**flags)

@property
def create_time(self):
@@ -607,6 +609,11 @@ class Cell:
cell.add_flags_recursive(**flags)
return self

def get_flags(self):
if not hasattr(self, "_mindspore_flags"):
self._mindspore_flags = {}
return self._mindspore_flags

def to_float(self, dst_type):
"""
Add cast on all inputs of cell and child cells to run with certain float type.


+ 1
- 1
mindspore/nn/wrap/cell_wrapper.py View File

@@ -219,7 +219,7 @@ class DataWrapper(Cell):
"""

def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False)
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())

self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network


Loading…
Cancel
Save