Browse Source

fix bug in dataset helper

tags/v1.1.0
liyong 5 years ago
parent
commit
602a8e52a0
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      mindspore/train/dataset_helper.py

+ 3
- 0
mindspore/train/dataset_helper.py View File

@@ -106,6 +106,9 @@ def connect_network_with_dataset(network, dataset_helper):
if hasattr(dataset, '__network_manage__') and key in dataset.__network_manage__:
network = dataset.__network_manage__[key]
else:
if _need_to_full():
device_num = _get_device_num()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
dataset.__network_manage__ = dataset.__network_manage__ if hasattr(
dataset, '__network_manage__') else dict()


Loading…
Cancel
Save