Browse Source

Repair parameter check problem in TFRecordDataset

tags/v0.3.0-alpha
ms_yan chang zherui 5 years ago
parent
commit
51feea03a4
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      mindspore/dataset/engine/validators.py

+ 5
- 0
mindspore/dataset/engine/validators.py View File

@@ -398,6 +398,7 @@ def check_tfrecorddataset(method):


nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_list = ['columns_list'] nreq_param_list = ['columns_list']
nreq_param_bool = ['shard_equal_rows']


# check dataset_files; required argument # check dataset_files; required argument
dataset_files = param_dict.get('dataset_files') dataset_files = param_dict.get('dataset_files')
@@ -410,6 +411,10 @@ def check_tfrecorddataset(method):


check_param_type(nreq_param_list, param_dict, list) check_param_type(nreq_param_list, param_dict, list)


check_param_type(nreq_param_bool, param_dict, bool)

check_sampler_shuffle_shard_options(param_dict)

return method(*args, **kwargs) return method(*args, **kwargs)


return new_method return new_method


Loading…
Cancel
Save