|
|
|
@@ -127,9 +127,12 @@ def serialize_operations(node_repr, key, val): |
|
|
|
|
|
|
|
def serialize_sampler(node_repr, val):
    """Serialize a sampler object into ``node_repr['sampler']``.

    Args:
        node_repr (dict): Serialized representation of the dataset node;
            updated in place under the ``'sampler'`` key.
        val: Sampler instance to serialize, or ``None`` when the node has
            no sampler.

    The stored dictionary holds the sampler's instance attributes plus two
    extra keys, ``'sampler_module'`` and ``'sampler_name'``, recording the
    module and class name of the sampler's type.
    """
    if val is None:
        node_repr['sampler'] = None
        return

    # Work on a shallow copy: writing the bookkeeping keys straight into
    # val.__dict__ would add them as attributes on the live sampler object.
    sampler_repr = dict(vars(val))
    sampler_repr['sampler_module'] = type(val).__module__
    sampler_repr['sampler_name'] = type(val).__name__
    node_repr['sampler'] = sampler_repr
|
|
|
|
|
|
|
|
|
|
|
def traverse(node): |
|
|
|
@@ -253,9 +256,10 @@ def create_node(node): |
|
|
|
node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) |
|
|
|
|
|
|
|
elif dataset_op == 'MindDataset': |
|
|
|
pyobj = pyclass(node['dataset_file'], node.get('column_list'), |
|
|
|
sampler = construct_sampler(node.get('sampler')) |
|
|
|
pyobj = pyclass(node['dataset_file'], node.get('columns_list'), |
|
|
|
node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'), |
|
|
|
node.get('shard_id'), node.get('block_reader')) |
|
|
|
node.get('shard_id'), node.get('block_reader'), sampler) |
|
|
|
|
|
|
|
elif dataset_op == 'TFRecordDataset': |
|
|
|
pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), |
|
|
|
@@ -341,24 +345,25 @@ def create_node(node): |
|
|
|
|
|
|
|
def construct_sampler(in_sampler):
    """Instantiate a Sampler object from the information in dictionary['sampler'].

    Args:
        in_sampler (dict or None): Serialized sampler. It must contain
            ``'sampler_name'`` and ``'sampler_module'``, plus the
            constructor arguments of the named sampler. ``None`` means the
            node has no sampler.

    Returns:
        The constructed sampler instance, or ``None`` if ``in_sampler`` is
        ``None``.

    Raises:
        ValueError: If ``sampler_name`` is not one of the supported samplers.
    """
    if in_sampler is None:
        return None

    sampler_name = in_sampler['sampler_name']
    sampler_module = in_sampler['sampler_module']
    # The class is looked up in an already-imported module; nothing is
    # imported here.
    sampler_class = getattr(sys.modules[sampler_module], sampler_name)

    if sampler_name == 'DistributedSampler':
        sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
    elif sampler_name == 'PKSampler':
        # 'shuffle' is read with .get(); calling the dict itself raised TypeError.
        sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler.get('shuffle'))
    elif sampler_name == 'RandomSampler':
        sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
    elif sampler_name == 'SequentialSampler':
        sampler = sampler_class()
    elif sampler_name == 'SubsetRandomSampler':
        sampler = sampler_class(in_sampler['indices'])
    elif sampler_name == 'WeightedRandomSampler':
        sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
    else:
        raise ValueError("Sampler type is unknown: " + sampler_name)

    return sampler
|
|
|
|
|
|
|
|