From dd63faa90c00e6a407535255c48fc4cf2891eaf1 Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Fri, 11 Dec 2020 21:30:04 +0800 Subject: [PATCH] modify dataset_helper example --- mindspore/train/dataset_helper.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index e7ef2e14c5..a95ec9d266 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -62,7 +62,7 @@ def connect_network_with_dataset(network, dataset_helper): Examples: >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset - >>> train_dataset = create_dataset() + >>> train_dataset = create_custom_dataset() >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True) >>> net = Net() >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) @@ -152,9 +152,13 @@ class DatasetHelper: epoch_num (int): Control the number of epoch data to send. Default: 1. Examples: - >>> dataset_helper = DatasetHelper(dataset) - >>> for inputs in dataset_helper: - >>> outputs = network(*inputs) + >>> network = Net() + >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + >>> network = nn.WithLossCell(network, net_loss) + >>> train_dataset = create_custom_dataset() + >>> dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) + >>> for next_element in dataset_helper: + ... outputs = network(*next_element) """ def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):