
!11379 fix missing return description in comment

From: @tiancixiao
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
parent commit 344567e5d7
16 changed files with 165 additions and 107 deletions
  1. mindspore/dataset/core/config.py (+9, -8)
  2. mindspore/dataset/engine/datasets.py (+60, -49)
  3. mindspore/dataset/engine/graphdata.py (+10, -10)
  4. mindspore/dataset/engine/samplers.py (+2, -1)
  5. mindspore/dataset/vision/py_transforms.py (+3, -2)
  6. mindspore/explainer/_image_classification_runner.py (+1, -1)
  7. mindspore/mindrecord/filewriter.py (+12, -8)
  8. mindspore/mindrecord/mindpage.py (+22, -5)
  9. mindspore/mindrecord/shardwriter.py (+4, -1)
  10. mindspore/mindrecord/tools/cifar100_to_mr.py (+7, -3)
  11. mindspore/mindrecord/tools/cifar10_to_mr.py (+8, -3)
  12. mindspore/mindrecord/tools/csv_to_mr.py (+2, -1)
  13. mindspore/mindrecord/tools/imagenet_to_mr.py (+4, -3)
  14. mindspore/mindrecord/tools/mnist_to_mr.py (+4, -3)
  15. mindspore/mindrecord/tools/tfrecord_to_mr.py (+3, -3)
  16. model_zoo/official/recommend/ncf/src/dataset.py (+14, -6)

mindspore/dataset/core/config.py (+9, -8)

@@ -102,7 +102,7 @@ def get_seed():
     Get the seed.

     Returns:
-        Int, seed.
+        int, seed.
     """
     return _config.get_seed()

@@ -131,7 +131,7 @@ def get_prefetch_size():
     Get the prefetch size in number of rows.

     Returns:
-        Size, total number of rows to be prefetched.
+        int, total number of rows to be prefetched.
     """
     return _config.get_op_connector_size()

@@ -162,7 +162,7 @@ def get_num_parallel_workers():
     This is the DEFAULT num_parallel_workers value used for each op, it is not related to AutoNumWorker feature.

     Returns:
-        Int, number of parallel workers to be used as a default for each operation
+        int, number of parallel workers to be used as a default for each operation.
     """
     return _config.get_num_parallel_workers()

@@ -193,7 +193,7 @@ def get_numa_enable():
     This is the DEFAULT numa enabled value used for the all process.

     Returns:
-        boolean, the default state of numa enabled
+        bool, the default state of numa enabled.
     """
     return _config.get_numa_enable()

@@ -222,7 +222,7 @@ def get_monitor_sampling_interval():
     Get the default interval of performance monitor sampling.

     Returns:
-        Int, interval (in milliseconds) for performance monitor sampling.
+        int, interval (in milliseconds) for performance monitor sampling.
     """
     return _config.get_monitor_sampling_interval()

@@ -280,7 +280,8 @@ def get_auto_num_workers():
     Get the setting (turned on or off) automatic number of workers.

     Returns:
-        Bool, whether auto num worker feature is turned on
+        bool, whether auto num worker feature is turned on.
+
     Examples:
         >>> num_workers = ds.config.get_auto_num_workers()
     """

@@ -313,7 +314,7 @@ def get_callback_timeout():
     In case of a deadlock, the wait function will exit after the timeout period.

     Returns:
-        Int, the duration in seconds
+        int, the duration in seconds.
     """
     return _config.get_callback_timeout()

@@ -323,7 +324,7 @@ def __str__():
     String representation of the configurations.

     Returns:
-        Str, configurations.
+        str, configurations.
     """
     return str(_config)
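
Note: the functions touched above are the module-level getters on mindspore.dataset.config. A minimal sketch of exercising them, matching the return types the commit documents (the seed value is arbitrary):

    import mindspore.dataset as ds

    # Seed the pipeline RNG, then read settings back through the getters
    # whose docstrings this commit fixes; all return plain Python types.
    ds.config.set_seed(1234)
    print(ds.config.get_seed())                       # int, e.g. 1234
    print(ds.config.get_prefetch_size())              # int, rows to prefetch
    print(ds.config.get_num_parallel_workers())       # int, default per-op workers
    print(ds.config.get_monitor_sampling_interval())  # int, milliseconds
    print(ds.config.get_auto_num_workers())           # bool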




mindspore/dataset/engine/datasets.py (+60, -49)

@@ -80,7 +80,7 @@ def zip(datasets):
     The number of datasets must be more than 1.

     Returns:
-        Dataset, ZipDataset.
+        ZipDataset, dataset zipped.

     Raises:
         ValueError: If the number of datasets is 1.

@@ -149,8 +149,8 @@ class Dataset:
         Internal method to create an IR tree.

         Returns:
-            ir_tree, The onject of the IR tree.
-            dataset, the root dataset of the IR tree.
+            DatasetNode, the root node of the IR tree.
+            Dataset, the root dataset of the IR tree.
         """
         parent = self.parent
         self.parent = []

@@ -165,7 +165,7 @@ class Dataset:
         Internal method to parse the API tree into an IR tree.

         Returns:
-            DatasetNode, The root of the IR tree.
+            DatasetNode, the root node of the IR tree.
         """
         if len(self.parent) > 1:
             raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")

@@ -197,7 +197,7 @@ class Dataset:
         Args:

         Returns:
-            Python dictionary.
+            dict, attributes related to the current class.
         """
         args = dict()
         args["num_parallel_workers"] = self.num_parallel_workers

@@ -211,7 +211,7 @@ class Dataset:
             filename (str): filename of json file to be saved as

         Returns:
-            Str, JSON string of the pipeline.
+            str, JSON string of the pipeline.
         """
         return json.loads(self.parse_tree().to_json(filename))

@@ -258,6 +258,9 @@ class Dataset:
             drop_remainder (bool, optional): If True, will drop the last batch for each
                 bucket if it is not a full batch (default=False).

+        Returns:
+            BucketBatchByLengthDataset, dataset bucketed and batched by length.
+
         Examples:
             >>> import mindspore.dataset as ds
             >>>

@@ -371,6 +374,9 @@ class Dataset:
             num_batch (int): the number of batches without blocking at the start of each epoch.
             callback (function): The callback funciton that will be invoked when sync_update is called.

+        Returns:
+            SyncWaitDataset, dataset added a blocking condition.
+
         Raises:
             RuntimeError: If condition name already exists.

@@ -434,7 +440,7 @@ class Dataset:
             return a 'Dataset'.

         Returns:
-            Dataset, applied by the function.
+            Dataset, dataset applied by the function.

         Examples:
             >>> import mindspore.dataset as ds

@@ -650,7 +656,7 @@ class Dataset:
                 in parallel (default=None).

         Returns:
-            FilterDataset, dataset filter.
+            FilterDataset, dataset filtered.

         Examples:
             >>> import mindspore.dataset as ds

@@ -748,6 +754,9 @@ class Dataset:
         """
         Internal method called by split to calculate absolute split sizes and to
         do some error checking after calculating absolute split sizes.
+
+        Returns:
+            int, absolute split sizes of the dataset.
         """
         # Call get_dataset_size here and check input here because
         # don't want to call this once in check_split and another time in

@@ -1015,7 +1024,7 @@ class Dataset:
             is specified and special_first is set to default, special_tokens will be prepended

         Returns:
-            Vocab node
+            Vocab, vocab built from the dataset.

         Example:
             >>> import mindspore.dataset as ds

@@ -1074,7 +1083,7 @@ class Dataset:
             params(dict): contains more optional parameters of sentencepiece library

         Returns:
-            SentencePieceVocab node
+            SentencePieceVocab, vocab built from the dataset.

         Example:
             >>> import mindspore.dataset as ds

@@ -1115,7 +1124,7 @@ class Dataset:
             return a preprogressing 'Dataset'.

         Returns:
-            Dataset, applied by the function.
+            Dataset, dataset applied by the function.

         Examples:
             >>> import mindspore.dataset as ds

@@ -1159,7 +1168,7 @@ class Dataset:
             If device is Ascend, features of data will be transferred one by one. The limitation
             of data transmission per time is 256M.

-        Return:
+        Returns:
             TransferDataset, dataset for transferring.
         """
         return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)

@@ -1287,7 +1296,7 @@ class Dataset:
                 use this param to select the conversion method, only take False for better performance (default=True).

         Returns:
-            Iterator, list of ndarrays.
+            TupleIterator, tuple iterator over the dataset.

         Examples:
             >>> import mindspore.dataset as ds

@@ -1322,7 +1331,7 @@ class Dataset:
                 if output_numpy=False, iterator will output MSTensor (default=False).

         Returns:
-            Iterator, dictionary of column name-ndarray pair.
+            DictIterator, dictionary iterator over the dataset.

         Examples:
             >>> import mindspore.dataset as ds

@@ -1352,6 +1361,9 @@ class Dataset:
         """
         Get Input Index Information

+        Returns:
+            tuple, tuple of the input index information.
+
         Examples:
             >>> import mindspore.dataset as ds
             >>>

@@ -1409,6 +1421,9 @@ class Dataset:
     def get_col_names(self):
         """
         Get names of the columns in the dataset
+
+        Returns:
+            list, list of column names in the dataset.
         """
         if self._col_names is None:
             runtime_getter = self._init_tree_getters()

@@ -1419,8 +1434,8 @@ class Dataset:
         """
         Get the shapes of output data.

-        Return:
-            List, list of shapes of each column.
+        Returns:
+            list, list of shapes of each column.
         """
         if self.saved_output_shapes is None:
             runtime_getter = self._init_tree_getters()

@@ -1432,8 +1447,8 @@ class Dataset:
         """
         Get the types of output data.

-        Return:
-            List of data types.
+        Returns:
+            list, list of data types.
         """
         if self.saved_output_types is None:
             runtime_getter = self._init_tree_getters()

@@ -1445,8 +1460,8 @@ class Dataset:
         """
        Get the number of batches in an epoch.

-        Return:
-            Number, number of batches.
+        Returns:
+            int, number of batches.
         """
         if self.dataset_size is None:
             runtime_getter = self._init_size_getter()

@@ -1457,8 +1472,8 @@ class Dataset:
         """
         Get the number of classes in a dataset.

-        Return:
-            Number, number of classes.
+        Returns:
+            int, number of classes.
         """
         if self._num_classes is None:
             runtime_getter = self._init_tree_getters()

@@ -1511,8 +1526,8 @@ class Dataset:
         """
         Get the size of a batch.

-        Return:
-            Number, the number of data in a batch.
+        Returns:
+            int, the number of data in a batch.
         """
         if self._batch_size is None:
             runtime_getter = self._init_tree_getters()

@@ -1525,8 +1540,8 @@ class Dataset:
         """
         Get the replication times in RepeatDataset else 1.

-        Return:
-            Number, the count of repeat.
+        Returns:
+            int, the count of repeat.
         """
         if self._repeat_count is None:
             runtime_getter = self._init_tree_getters()

@@ -1540,8 +1555,8 @@ class Dataset:
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
-            Dict, A str-to-list<int> mapping from label name to index for Coco ONLY. The second number
+            dict, a str-to-int mapping from label name to index.
+            dict, a str-to-list<int> mapping from label name to index for Coco ONLY. The second number
                 in the list is used to indicate the super category
         """
         if self.children:

@@ -1588,7 +1603,7 @@ class SourceDataset(Dataset):
             patterns (Union[str, list[str]]): String or list of patterns to be searched.

         Returns:
-            List, files.
+            list, list of files.
         """

         if not isinstance(patterns, list):

@@ -1646,9 +1661,6 @@ class MappableDataset(SourceDataset):
         Args:
             new_sampler (Sampler): The sampler to use for the current dataset.

-        Returns:
-            Dataset, that uses new_sampler.
-
         Examples:
             >>> import mindspore.dataset as ds
             >>>

@@ -1909,8 +1921,9 @@ class BatchDataset(Dataset):

         Args:
             dataset (Dataset): Dataset to be checked.
-        Return:
-            True or False.
+
+        Returns:
+            bool, whether repeat is used before batch.
         """
         if isinstance(dataset, RepeatDataset):
             return True

@@ -1995,18 +2008,12 @@ class BatchInfo(cde.CBatchInfo):
     def get_batch_num(self):
         """
         Return the batch number of the current batch.
-
-        Return:
-            Number, number of the current batch.
         """
         return

     def get_epoch_num(self):
         """
         Return the epoch number of the current batch.
-
-        Return:
-            Number, number of the current epoch.
         """
         return

@@ -2055,8 +2062,8 @@ class BlockReleasePair:
         """
         Function for handing blocking condition.

-        Return:
-            True
+        Returns:
+            bool, True.
         """
         with self.cv:
             # if disable is true, the always evaluate to true

@@ -2145,8 +2152,9 @@ class SyncWaitDataset(Dataset):

         Args:
             dataset (Dataset): Dataset to be checked.
-        Return:
-            True or False.
+
+        Returns:
+            bool, whether sync_wait is used before batch.
         """
         if isinstance(dataset, BatchDataset):
             return True

@@ -2932,6 +2940,9 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, n
         num_shards (int): Number of shard for sharding.
         shard_id (int): Shard ID.
         non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
+
+    Returns:
+        Sampler, sampler selected based on user input.
     """
     if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
         return None

@@ -4180,7 +4191,7 @@ class ManifestDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
+            dict, a str-to-int mapping from label name to index.
         """
         if self.class_indexing is None:
             if self._class_indexing is None:

@@ -4579,7 +4590,7 @@ class Schema:
         Args:
             schema_file(str): Path of schema file (default=None).

-        Return:
+        Returns:
             Schema object, schema info about dataset.

         Raises:

@@ -4654,7 +4665,7 @@ class Schema:
         Get a JSON string of the schema.

         Returns:
-            Str, JSON string of the schema.
+            str, JSON string of the schema.
         """
         return self.cpp_schema.to_json()

@@ -4840,7 +4851,7 @@ class VOCDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-int mapping from label name to index.
+            dict, a str-to-int mapping from label name to index.
         """
         if self.task != "Detection":
             raise NotImplementedError("Only 'Detection' support get_class_indexing.")

@@ -5032,7 +5043,7 @@ class CocoDataset(MappableDataset):
         Get the class index.

         Returns:
-            Dict, A str-to-list<int> mapping from label name to index
+            dict, a str-to-list<int> mapping from label name to index
         """
         if self.task not in {"Detection", "Panoptic"}:
             raise NotImplementedError("Only 'Detection' and 'Panoptic' support get_class_indexing.")
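
Note: several of the corrected docstrings above (get_dataset_size, get_col_names, output_shapes, create_dict_iterator) belong to the shared Dataset interface. A minimal sketch under the assumption of a small in-memory pipeline; the NumpySlicesDataset source and the column name are illustrative, not part of this commit:

    import numpy as np
    import mindspore.dataset as ds

    # A tiny in-memory pipeline; "col1" is an illustrative column name.
    data = ds.NumpySlicesDataset({"col1": np.arange(6)}, shuffle=False)
    data = data.batch(2)

    print(data.get_dataset_size())  # int, number of batches in an epoch (3)
    print(data.get_col_names())     # list, list of column names
    print(data.output_shapes())     # list, list of shapes of each column
    print(data.output_types())      # list, list of data types

    # create_dict_iterator returns a DictIterator, as now documented.
    for row in data.create_dict_iterator(output_numpy=True):
        print(row["col1"])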


mindspore/dataset/engine/graphdata.py (+10, -10)

@@ -100,7 +100,7 @@ class GraphData:
         node_type (int): Specify the type of node.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of nodes.

     Examples:
         >>> import mindspore.dataset as ds

@@ -124,7 +124,7 @@ class GraphData:
         edge_type (int): Specify the type of edge.

     Returns:
-        numpy.ndarray: array of edges.
+        numpy.ndarray, array of edges.

     Examples:
         >>> import mindspore.dataset as ds

@@ -148,7 +148,7 @@ class GraphData:
         edge_list (Union[list, numpy.ndarray]): The given list of edges.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of nodes.

     Raises:
         TypeError: If `edge_list` is not list or ndarray.

@@ -167,7 +167,7 @@ class GraphData:
         neighbor_type (int): Specify the type of neighbor.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of neighbors.

     Examples:
         >>> import mindspore.dataset as ds

@@ -201,7 +201,7 @@ class GraphData:
         neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of neighbors.

     Examples:
         >>> import mindspore.dataset as ds

@@ -231,7 +231,7 @@ class GraphData:
         neg_neighbor_type (int): Specify the type of negative neighbor.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of neighbors.

     Examples:
         >>> import mindspore.dataset as ds

@@ -260,7 +260,7 @@ class GraphData:
         feature_types (Union[list, numpy.ndarray]): The given list of feature types.

     Returns:
-        numpy.ndarray: array of features.
+        numpy.ndarray, array of features.

     Examples:
         >>> import mindspore.dataset as ds

@@ -292,7 +292,7 @@ class GraphData:
         feature_types (Union[list, numpy.ndarray]): The given list of feature types.

     Returns:
-        numpy.ndarray: array of features.
+        numpy.ndarray, array of features.

     Examples:
         >>> import mindspore.dataset as ds

@@ -320,7 +320,7 @@ class GraphData:
         the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.

     Returns:
-        dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
+        dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
             node_feature_type and edge_feature_type.
     """
     if self._working_mode == 'server':

@@ -347,7 +347,7 @@ class GraphData:
         A default value of -1 indicates that no node is given.

     Returns:
-        numpy.ndarray: Array of nodes.
+        numpy.ndarray, array of nodes.

     Examples:
         >>> import mindspore.dataset as ds
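
Note: the GraphData accessors above now all document numpy.ndarray returns. A minimal sketch, assuming a graph already converted to MindRecord at the hypothetical path "cora.mindrecord" with node type 1 and feature type 1:

    import mindspore.dataset as ds

    # File path and type ids are assumptions for illustration.
    g = ds.GraphData("cora.mindrecord")

    nodes = g.get_all_nodes(node_type=1)  # numpy.ndarray, array of nodes
    neighbors = g.get_all_neighbors(nodes[:10].tolist(), neighbor_type=1)
    features = g.get_node_feature(nodes[:10].tolist(), feature_types=[1])
    print(g.graph_info())                 # dict, meta information of the graph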


mindspore/dataset/engine/samplers.py (+2, -1)

@@ -128,6 +128,7 @@ class BuiltinSampler:

     User should not extend this class.
     """
+
     def __init__(self, num_samples=None):
         self.child_sampler = None
         self.num_samples = num_samples

@@ -201,7 +202,7 @@ class BuiltinSampler:
             - None

         Returns:
-            int, The number of samples, or None
+            int, the number of samples, or None
         """
         if self.child_sampler is not None:
             child_samples = self.child_sampler.get_num_samples()
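
Note: get_num_samples can also return None when no count was set anywhere in the sampler chain, which is why the docstring keeps the ", or None" wording. A short sketch (the counts are arbitrary):

    import mindspore.dataset as ds

    sampler = ds.SequentialSampler(num_samples=8)
    print(sampler.get_num_samples())    # 8

    unbounded = ds.RandomSampler()      # no num_samples given
    print(unbounded.get_num_samples())  # None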


mindspore/dataset/vision/py_transforms.py (+3, -2)

@@ -1063,8 +1063,9 @@ class LinearTransformation:
     the dot product with the transformation matrix, and reshapes it back to its original shape.

     Args:
-        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
-        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
+        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), where
+            :math:`D = C \times H \times W`.
+        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where :math:`D = C \times H \times W`.

     Examples:
         >>> from mindspore.dataset.transforms.py_transforms import Compose
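
Note: with the clarified shapes, transformation_matrix is (D, D) and mean_vector is (D,) for D = C x H x W. A minimal sketch using an identity transform, which leaves the image unchanged; the 3 x 32 x 32 shape is arbitrary:

    import numpy as np
    import mindspore.dataset.vision.py_transforms as py_vision

    C, H, W = 3, 32, 32
    D = C * H * W  # D = C x H x W, as the fixed docstring states

    # Identity matrix and zero mean vector: a no-op linear transformation.
    transform = py_vision.LinearTransformation(np.eye(D), np.zeros(D))

    img = np.random.rand(C, H, W).astype(np.float32)
    out = transform(img)
    print(out.shape)  # (3, 32, 32), reshaped back to the original shape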


mindspore/explainer/_image_classification_runner.py (+1, -1)

@@ -23,7 +23,7 @@ from PIL import Image
 import mindspore as ms
 import mindspore.dataset as ds
 from mindspore import log
-from mindspore.dataset.engine.datasets import Dataset
+from mindspore.dataset import Dataset
 from mindspore.nn import Cell, SequentialCell
 from mindspore.ops.operations import ExpandDims
 from mindspore.train._utils import check_value_type


mindspore/mindrecord/filewriter.py (+12, -8)

@@ -30,6 +30,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchema

 __all__ = ['FileWriter']

+
 class FileWriter:
     """
     Class to write user defined raw data into MindRecord File series.

@@ -45,6 +46,7 @@ class FileWriter:
     Raises:
         ParamValueError: If `file_name` or `shard_num` is invalid.
     """
+
     def __init__(self, file_name, shard_num=1):
         check_filename(file_name)
         self._file_name = file_name

@@ -84,7 +86,7 @@ class FileWriter:
             file_name (str): String of MindRecord file name.

         Returns:
-            Instance of FileWriter.
+            FileWriter, file writer for the opened MindRecord file.

         Raises:
             ParamValueError: If file_name is invalid.

@@ -118,7 +120,7 @@ class FileWriter:
             desc (str, optional): String of schema description (default=None).

         Returns:
-            An integer, schema id.
+            int, schema id.

         Raises:
             MRMInvalidSchemaError: If schema is invalid.

@@ -175,17 +177,17 @@ class FileWriter:

                 if field not in v:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "there is not '{}' object in the raw data.".format(i, field)
+                                        "there is not '{}' object in the raw data.".format(i, field)
                     continue
                 field_type = type(v[field]).__name__
                 if field_type not in VALUE_TYPE_MAP:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                     continue

                 if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]:
                     error_data_dic[i] = "for schema, {} th data is wrong, " \
-                        "data type for '{}' is not matched.".format(i, field)
+                                        "data type for '{}' is not matched.".format(i, field)
                     continue

                 if field_type == 'ndarray':

@@ -206,7 +208,6 @@ class FileWriter:
     def open_and_set_header(self):
         """
         Open writer and set header.
-
         """
         if not self._writer.is_open:
             self._writer.open(self._paths)

@@ -222,6 +223,9 @@ class FileWriter:
             raw_data (list[dict]): List of raw data.
             parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).

+        Returns:
+            MSRStatus, SUCCESS or FAILED.
+
         Raises:
             ParamTypeError: If index field is invalid.
             MRMOpenError: If failed to open MindRecord File.

@@ -330,7 +334,7 @@ class FileWriter:
             v (dict): Sub dict in schema

         Returns:
-            bool, True or False.
+            bool, whether the array item is valid.
             str, error message.
         """
         if v['type'] not in VALID_ARRAY_ATTRIBUTES:

@@ -355,7 +359,7 @@ class FileWriter:
             content (dict): Dict of raw schema.

         Returns:
-            bool, True or False.
+            bool, whether the schema is valid.
             str, error message.
         """
         error = ''
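
Note: the corrected docstrings state that open_for_append returns a FileWriter, add_schema an int schema id, and write_raw_data an MSRStatus. A minimal end-to-end sketch; the file name and schema are illustrative:

    from mindspore.mindrecord import FileWriter

    writer = FileWriter(file_name="example.mindrecord", shard_num=1)

    schema = {"id": {"type": "int64"}, "data": {"type": "bytes"}}
    schema_id = writer.add_schema(schema, "example_schema")  # int, schema id

    # write_raw_data now documents its return: MSRStatus, SUCCESS or FAILED.
    status = writer.write_raw_data([{"id": 0, "data": b"\x01\x02"}])
    writer.commit()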


mindspore/mindrecord/mindpage.py (+22, -5)

@@ -23,6 +23,7 @@ from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategor

 __all__ = ['MindPage']

+
 class MindPage:
     """
     Class to read MindRecord File series in pagination.

@@ -36,6 +37,7 @@ class MindPage:
         ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
         MRMInitSegmentError: If failed to initialize ShardSegment.
     """
+
     def __init__(self, file_name, num_consumer=4):
         if isinstance(file_name, list):
             for f in file_name:

@@ -69,7 +71,12 @@ class MindPage:
         return self._candidate_fields

     def get_category_fields(self):
-        """Return candidate category fields."""
+        """
+        Return candidate category fields.
+
+        Returns:
+            list[str], by which data could be grouped.
+        """
         logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
                        " Please use candidate_fields")
         return self.candidate_fields

@@ -97,12 +104,22 @@ class MindPage:

     @property
     def category_field(self):
-        """Getter function for category fields."""
+        """
+        Getter function for category fields.
+
+        Returns:
+            list[str], by which data could be grouped.
+        """
         return self._category_field

     @category_field.setter
     def category_field(self, category_field):
-        """Setter function for category field"""
+        """
+        Setter function for category field.
+
+        Returns:
+            MSRStatus, SUCCESS or FAILED.
+        """
         if not category_field or not isinstance(category_field, str):
             raise ParamTypeError('category_fields', 'str')
         if category_field not in self._candidate_fields:

@@ -132,7 +149,7 @@ class MindPage:
             num_row (int): Number of rows in a page.

         Returns:
-            List, list[dict].
+            list[dict], data queried by category id.

         Raises:
             ParamValueError: If any parameter is invalid.

@@ -158,7 +175,7 @@ class MindPage:
             num_row (int): Number of row in a page.

         Returns:
-            str, read at page.
+            list[dict], data queried by category name.
         """
         if not isinstance(category_name, str):
             raise ParamValueError("Category name should be str.")
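
Note: MindPage's getters now document list[str] and list[dict] returns. A minimal sketch, assuming an indexed MindRecord file with a 'label' field; the file name is illustrative:

    from mindspore.mindrecord import MindPage

    reader = MindPage("example.mindrecord")  # illustrative file name
    print(reader.candidate_fields)           # list[str], groupable fields

    reader.category_field = "label"          # assumes a 'label' index field

    # list[dict]: up to 10 rows of category 0, page 0.
    rows = reader.read_at_page_by_id(0, 0, 10)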


mindspore/mindrecord/shardwriter.py (+4, -1)

@@ -23,6 +23,7 @@ from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHe

 __all__ = ['ShardWriter']

+
 class ShardWriter:
     """
     Wrapper class which is represent shardWrite class in c++ module.

@@ -192,9 +193,11 @@ class ShardWriter:
         if len(blob_data) == 1:
             values = [v for v in blob_data.values()]
             return bytes(values[0])
+
         # convert int to bytes
         def int_to_bytes(x: int) -> bytes:
             return x.to_bytes(8, 'big')
+
         merged = bytes()
         for field, v in blob_data.items():
             # convert ndarray to bytes

@@ -209,7 +212,7 @@ class ShardWriter:
         Flush data to disk.

         Returns:
-            Class MSRStatus, SUCCESS or FAILED.
+            MSRStatus, SUCCESS or FAILED.

         Raises:
             MRMCommitError: If failed to flush data to disk.
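
Note: the reflowed helper above suggests each ndarray field is framed with an 8-byte big-endian length before concatenation. A standalone sketch of that framing convention, stated as an assumption about the surrounding merge loop rather than a copy of it:

    import numpy as np

    def int_to_bytes(x: int) -> bytes:
        # Same convention as the helper in the diff: 8-byte big-endian.
        return x.to_bytes(8, 'big')

    arr = np.array([1, 2, 3], dtype=np.int64)
    blob = arr.tobytes()
    merged = int_to_bytes(len(blob)) + blob  # length prefix, then payload
    print(len(merged))                       # 8 + 24 = 32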


mindspore/mindrecord/tools/cifar100_to_mr.py (+7, -3)

@@ -33,6 +33,7 @@ except ModuleNotFoundError:

 __all__ = ['Cifar100ToMR']

+
 class Cifar100ToMR:
     """
     A class to transform from cifar100 to MindRecord.

@@ -44,6 +45,7 @@ class Cifar100ToMR:
     Raises:
         ValueError: If source or destination is invalid.
     """
+
     def __init__(self, source, destination):
         check_filename(source)
         self.source = source

@@ -74,7 +76,7 @@ class Cifar100ToMR:
             fields (list[str]): A list of index field, e.g.["fine_label", "coarse_label"].

         Returns:
-            SUCCESS or FAILED, whether cifar100 is successfully transformed to MindRecord.
+            MSRStatus, whether cifar100 is successfully transformed to MindRecord.
         """
         if fields and not isinstance(fields, list):
             raise ValueError("The parameter fields should be None or list")

@@ -114,6 +116,7 @@ class Cifar100ToMR:
             raise t.exception
         return t.res

+
 def _construct_raw_data(images, fine_labels, coarse_labels):
     """
     Construct raw data from cifar100 data.

@@ -124,7 +127,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
         coarse_labels (list): coarse label list from cifar100.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        list[dict], data dictionary constructed from cifar100.
     """
     if not cv2:
         raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")

@@ -141,6 +144,7 @@ def _construct_raw_data(images, fine_labels, coarse_labels):
         raw_data.append(row_data)
     return raw_data

+
 def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
     """
     Generate MindRecord file from raw data.

@@ -153,7 +157,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
         schema_desc (str): String of schema description.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        MSRStatus, whether successfully written into MindRecord.
     """
     schema = {"id": {"type": "int64"}, "fine_label": {"type": "int64"},
               "coarse_label": {"type": "int64"}, "data": {"type": "bytes"}}


mindspore/mindrecord/tools/cifar10_to_mr.py (+8, -3)

@@ -25,6 +25,7 @@ from .cifar10 import Cifar10
 from ..common.exceptions import PathNotExistsError
 from ..filewriter import FileWriter
 from ..shardutils import check_filename, ExceptionThread, SUCCESS, FAILED
+
 try:
     cv2 = import_module("cv2")
 except ModuleNotFoundError:

@@ -32,6 +33,7 @@ except ModuleNotFoundError:

 __all__ = ['Cifar10ToMR']

+
 class Cifar10ToMR:
     """
     A class to transform from cifar10 to MindRecord.

@@ -43,6 +45,7 @@ class Cifar10ToMR:
     Raises:
         ValueError: If source or destination is invalid.
     """
+
     def __init__(self, source, destination):
         check_filename(source)
         self.source = source

@@ -73,7 +76,7 @@ class Cifar10ToMR:
             fields (list[str], optional): A list of index fields, e.g.["label"] (default=None).

         Returns:
-            SUCCESS or FAILED, whether cifar10 is successfully transformed to MindRecord.
+            MSRStatus, whether cifar10 is successfully transformed to MindRecord.
         """
         if fields and not isinstance(fields, list):
             raise ValueError("The parameter fields should be None or list")

@@ -109,6 +112,7 @@ class Cifar10ToMR:
             raise t.exception
         return t.res

+
 def _construct_raw_data(images, labels):
     """
     Construct raw data from cifar10 data.

@@ -118,7 +122,7 @@ def _construct_raw_data(images, labels):
         labels (list): label list from cifar10.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        list[dict], data dictionary constructed from cifar10.
     """
     if not cv2:
         raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")

@@ -133,6 +137,7 @@ def _construct_raw_data(images, labels):
         raw_data.append(row_data)
     return raw_data

+
 def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
     """
     Generate MindRecord file from raw data.

@@ -145,7 +150,7 @@ def _generate_mindrecord(file_name, raw_data, fields, schema_desc):
         schema_desc (str): String of schema description.

     Returns:
-        SUCCESS/FAILED, whether successfully written into MindRecord.
+        MSRStatus, whether successfully written into MindRecord.
     """
     schema = {"id": {"type": "int64"}, "label": {"type": "int64"},
               "data": {"type": "bytes"}}


mindspore/mindrecord/tools/csv_to_mr.py (+2, -1)

@@ -29,6 +29,7 @@ except ModuleNotFoundError:

 __all__ = ['CsvToMR']

+
 class CsvToMR:
     """
     A class to transform from csv to MindRecord.

@@ -121,7 +122,7 @@ class CsvToMR:
         Executes transformation from csv to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether csv is successfully transformed to MindRecord.
+            MSRStatus, whether csv is successfully transformed to MindRecord.
         """
         if not os.path.exists(self.source):
             raise IOError("Csv file {} do not exist.".format(self.source))


mindspore/mindrecord/tools/imagenet_to_mr.py (+4, -3)

@@ -47,6 +47,7 @@ class ImageNetToMR:
     Raises:
         ValueError: If `map_file`, `image_dir` or `destination` is invalid.
     """
+
     def __init__(self, map_file, image_dir, destination, partition_number=1):
         check_filename(map_file)
         self.map_file = map_file

@@ -122,7 +123,7 @@ class ImageNetToMR:
         Executes transformation from imagenet to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether imagenet is successfully transformed to MindRecord.
+            MSRStatus, whether imagenet is successfully transformed to MindRecord.
         """
         t0_total = time.time()

@@ -133,10 +134,10 @@ class ImageNetToMR:
         logger.info("transformed MindRecord schema is: {}".format(imagenet_schema_json))

         # set the header size
-        self.writer.set_header_size(1<<24)
+        self.writer.set_header_size(1 << 24)

         # set the page size
-        self.writer.set_page_size(1<<26)
+        self.writer.set_page_size(1 << 26)

         # create the schema
         self.writer.add_schema(imagenet_schema_json, "imagenet_schema")
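
Note: the spacing fix touches the bit-shift constants that size the MindRecord header and page; for reference, they evaluate to 16 MiB and 64 MiB:

    # 1 << 24 bytes = 16 MiB header; 1 << 26 bytes = 64 MiB page.
    print(1 << 24)  # 16777216
    print(1 << 26)  # 67108864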


mindspore/mindrecord/tools/mnist_to_mr.py (+4, -3)

@@ -32,6 +32,7 @@ except ModuleNotFoundError:

 __all__ = ['MnistToMR']

+
 class MnistToMR:
     """
     A class to transform from Mnist to MindRecord.

@@ -125,7 +126,7 @@ class MnistToMR:
         Executes transformation from Mnist train part to MindRecord.

         Returns:
-            SUCCESS/FAILED, whether successfully written into MindRecord.
+            MSRStatus, whether successfully written into MindRecord.
         """
         t0_total = time.time()

@@ -173,7 +174,7 @@ class MnistToMR:
         Executes transformation from Mnist test part to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether Mnist is successfully transformed to MindRecord.
+            MSRStatus, whether Mnist is successfully transformed to MindRecord.
         """
         t0_total = time.time()

@@ -222,7 +223,7 @@ class MnistToMR:
         Executes transformation from Mnist to MindRecord.

         Returns:
-            SUCCESS/FAILED, whether successfully written into MindRecord.
+            MSRStatus, whether successfully written into MindRecord.
         """
         if not cv2:
             raise ModuleNotFoundError("opencv-python module not found, please use pip install it.")


mindspore/mindrecord/tools/tfrecord_to_mr.py (+3, -3)

@@ -23,7 +23,6 @@ from mindspore import log as logger
 from ..filewriter import FileWriter
 from ..shardutils import check_filename, ExceptionThread

-
 __all__ = ['TFRecordToMR']

 SupportedTensorFlowVersion = '1.13.0-rc1'

@@ -86,9 +85,10 @@ class TFRecordToMR:
         ValueError: If parameter is invalid.
         Exception: when tensorflow module is not found or version is not correct.
     """
+
     def __init__(self, source, destination, feature_dict, bytes_fields=None):
         try:
-            self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
+            self.tf = import_module("tensorflow")  # just used to convert tfrecord to mindrecord
         except ModuleNotFoundError:
             self.tf = None
         if not self.tf:

@@ -265,7 +265,7 @@ class TFRecordToMR:
         Execute transformation from TFRecord to MindRecord.

         Returns:
-            SUCCESS or FAILED, whether TFRecord is successfuly transformed to MindRecord.
+            MSRStatus, whether TFRecord is successfuly transformed to MindRecord.
         """
         writer = FileWriter(self.destination)
         logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
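
Note: a minimal sketch of driving TFRecordToMR with the documented MSRStatus return; the paths and the feature dict are illustrative and must mirror the actual TFRecord schema:

    import tensorflow as tf
    from mindspore.mindrecord import TFRecordToMR

    # Illustrative paths; bytes_fields marks columns kept as raw bytes.
    transformer = TFRecordToMR(
        source="data.tfrecord",
        destination="data.mindrecord",
        feature_dict={"image_raw": tf.io.FixedLenFeature([], tf.string),
                      "label": tf.io.FixedLenFeature([], tf.int64)},
        bytes_fields=["image_raw"])
    status = transformer.transform()  # MSRStatus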


model_zoo/official/recommend/ncf/src/dataset.py (+14, -6)

@@ -22,13 +22,12 @@ import pickle
 import numpy as np
 import pandas as pd

-from mindspore.dataset.engine import GeneratorDataset
+from mindspore.dataset import GeneratorDataset

 import src.constants as rconst
 import src.movielens as movielens
 import src.stat_utils as stat_utils

-
 DATASET_TO_NUM_USERS_AND_ITEMS = {
     "ml-1m": (6040, 3706),
     "ml-20m": (138493, 26744)

@@ -205,6 +204,7 @@ class NCFDataset:
     """
     A dataset for NCF network.
     """
+
     def __init__(self,
                  pos_users,
                  pos_items,

@@ -407,6 +407,7 @@ class RandomSampler:
     """
     A random sampler for dataset.
     """
+
     def __init__(self, pos_count, num_train_negatives, batch_size):
         self.pos_count = pos_count
         self._num_samples = (1 + num_train_negatives) * self.pos_count

@@ -433,6 +434,7 @@ class DistributedSamplerOfTrain:
     """
     A distributed sampler for dataset.
     """
+
     def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
         """
         Distributed sampler of training dataset.

@@ -443,15 +445,16 @@
         self._batch_size = batch_size

         self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
-         self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
+        self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
         self._total_num_samples = self._samples_per_rank * self._rank_size

     def __iter__(self):
         """
         Returns the data after each sampling.
         """
         indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
         indices = indices.tolist()
-        indices.extend(indices[:self._total_num_samples-len(indices)])
+        indices.extend(indices[:self._total_num_samples - len(indices)])
         indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
         batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]
+

@@ -463,10 +466,12 @@
         """
         return self._batchs_per_rank

+
 class SequenceSampler:
     """
     A sequence sampler for dataset.
     """
+
     def __init__(self, eval_batch_size, num_users):
         self._eval_users_per_batch = int(
             eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))

@@ -491,10 +496,12 @@
         """
         return self._eval_batches_per_epoch

+
 class DistributedSamplerOfEval:
     """
     A distributed sampler for eval dataset.
     """
+
     def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
         self._eval_users_per_batch = int(
             eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))

@@ -507,8 +514,8 @@
         self._eval_batch_size = eval_batch_size

         self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
-        #self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
-        #self._total_num_samples = self._samples_per_rank * self._rank_size
+        # self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
+        # self._total_num_samples = self._samples_per_rank * self._rank_size

     def __iter__(self):
         indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)

@@ -525,6 +532,7 @@
     def __len__(self):
         return self._batchs_per_rank

+
 def parse_eval_batch_size(eval_batch_size):
     """
     Parse eval batch size.
