Browse Source

delete DataWrapper

tags/v1.0.0
wangnan39@huawei.com 5 years ago
parent
commit
ef94ac12e9
3 changed files with 2 additions and 42 deletions
  1. +1
    -2
      mindspore/nn/wrap/__init__.py
  2. +0
    -39
      mindspore/nn/wrap/cell_wrapper.py
  3. +1
    -1
      mindspore/ops/operations/nn_ops.py

+ 1
- 2
mindspore/nn/wrap/__init__.py View File

@@ -17,7 +17,7 @@ Wrap cells for networks.

Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
@@ -27,7 +27,6 @@ __all__ = [
"WithLossCell",
"WithGradCell",
"WithEvalCell",
"DataWrapper",
"GetNextSingleOp",
"TrainOneStepWithLossScaleCell",
"DistributedGradReducer",


+ 0
- 39
mindspore/nn/wrap/cell_wrapper.py View File

@@ -205,45 +205,6 @@ class TrainOneStepCell(Cell):
return F.depend(loss, self.optimizer(grads))


class DataWrapper(Cell):
"""
Network training package class for dataset.

DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext'
function from the dataset channel 'queue_name' and does forward computation in the construct function.

Args:
network (Cell): The training network for dataset.
dataset_types (list): The type of dataset. The list contains the types of the inputs.
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe
the shape of the inputs.
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply
data for the network.

Outputs:
Tensor, network output whose shape depends on the network.

Examples:
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_dataset()
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
>>> net = Net()
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name)
"""

def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network

def construct(self):
outputs = self.get_next()
return self.network(*outputs)


class GetNextSingleOp(Cell):
"""
Cell to run for getting the next operation.


+ 1
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2538,7 +2538,7 @@ class GetNext(PrimitiveWithInfer):
Note:
The GetNext operation needs to be associated with network and it also depends on the init_dataset interface,
it can't be used directly as a single operation.
For details, please refer to `nn.DataWrapper` source code.
For details, please refer to `connect_network_with_dataset` source code.

Args:
types (list[:class:`mindspore.dtype`]): The type of the outputs.


Loading…
Cancel
Save