|
|
|
@@ -154,6 +154,16 @@ def connect_network_with_dataset(network, dataset_helper): |
|
|
|
dataset.__me_inited__ = True |
|
|
|
dataset_types, dataset_shapes = dataset_helper.types_shapes() |
|
|
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name) |
|
|
|
|
|
|
|
if hasattr(dataset_iter, "sink_size") and \ |
|
|
|
dataset_iter.sink_size == 1 and \ |
|
|
|
dataset.get_dataset_size() != 1 and \ |
|
|
|
hasattr(dataset_iter, "sink_count") and \ |
|
|
|
dataset_iter.sink_count == 1 and \ |
|
|
|
context.get_context("device_target") == "Ascend" and \ |
|
|
|
context.get_context("mode") == context.PYNATIVE_MODE: |
|
|
|
dataset_helper.get_data_info() |
|
|
|
|
|
|
|
return network |
|
|
|
|
|
|
|
|
|
|
|
|