|
|
@@ -205,45 +205,6 @@ class TrainOneStepCell(Cell): |
|
|
return F.depend(loss, self.optimizer(grads)) |
|
|
return F.depend(loss, self.optimizer(grads)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataWrapper(Cell): |
|
|
|
|
|
""" |
|
|
|
|
|
Network training package class for dataset. |
|
|
|
|
|
|
|
|
|
|
|
DataWrapper wraps the input network with a dataset which automatically fetches data with 'GetNext' |
|
|
|
|
|
function from the dataset channel 'queue_name' and does forward computation in the construct function. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
network (Cell): The training network for dataset. |
|
|
|
|
|
dataset_types (list): The type of dataset. The list contains the types of the inputs. |
|
|
|
|
|
dataset_shapes (list): The shapes of dataset. The list contains multiple sublists that describe |
|
|
|
|
|
the shape of the inputs. |
|
|
|
|
|
queue_name (str): The identification of dataset channel which specifies the dataset channel to supply |
|
|
|
|
|
data for the network. |
|
|
|
|
|
|
|
|
|
|
|
Outputs: |
|
|
|
|
|
Tensor, network output whose shape depends on the network. |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset |
|
|
|
|
|
>>> train_dataset = create_dataset() |
|
|
|
|
|
>>> dataset_helper = mindspore.DatasetHelper(train_dataset) |
|
|
|
|
|
>>> net = Net() |
|
|
|
|
|
>>> net = DataWrapper(net, *(dataset_helper.types_shapes()), train_dataset.queue_name) |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name): |
|
|
|
|
|
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) |
|
|
|
|
|
# Also copy the flag in `network` construct |
|
|
|
|
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {}) |
|
|
|
|
|
self.add_flags(**flags) |
|
|
|
|
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) |
|
|
|
|
|
self.network = network |
|
|
|
|
|
|
|
|
|
|
|
def construct(self): |
|
|
|
|
|
outputs = self.get_next() |
|
|
|
|
|
return self.network(*outputs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GetNextSingleOp(Cell): |
|
|
class GetNextSingleOp(Cell): |
|
|
""" |
|
|
""" |
|
|
Cell to run for getting the next operation. |
|
|
Cell to run for getting the next operation. |
|
|
|