From: @luoyang42 Reviewed-by: @liucunwei Signed-off-by: @liucunwei tags/v1.2.0-rc1
| @@ -51,12 +51,19 @@ Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor | |||
| Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()), | |||
| MSTypeToDEType(static_cast<TypeId>(input.DataType())), | |||
| (const uchar *)(input.Data().get()), input.DataSize(), &de_tensor); | |||
| RETURN_IF_NOT_OK(rc); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << rc; | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| // Apply transforms on tensor | |||
| for (auto &t : transforms) { | |||
| std::shared_ptr<dataset::Tensor> de_output; | |||
| RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output)); | |||
| Status rc_ = t->Compute(de_tensor, &de_output); | |||
| if (rc_.IsError()) { | |||
| MS_LOG(ERROR) << rc_; | |||
| RETURN_IF_NOT_OK(rc_); | |||
| } | |||
| // For next transform | |||
| de_tensor = std::move(de_output); | |||
| @@ -90,14 +97,21 @@ Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std:: | |||
| Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()), | |||
| MSTypeToDEType(static_cast<TypeId>(tensor.DataType())), | |||
| (const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor); | |||
| RETURN_IF_NOT_OK(rc); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << rc; | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| de_tensor_list.emplace_back(std::move(de_tensor)); | |||
| } | |||
| // Apply transforms on tensor | |||
| for (auto &t : transforms) { | |||
| TensorRow de_output_list; | |||
| RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list)); | |||
| Status rc = t->Compute(de_tensor_list, &de_output_list); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << rc; | |||
| RETURN_IF_NOT_OK(rc); | |||
| } | |||
| // For next transform | |||
| de_tensor_list = std::move(de_output_list); | |||
| } | |||
| @@ -31,12 +31,7 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) { | |||
| .def("__call__", | |||
| [](Execute &self, const std::shared_ptr<Tensor> &de_tensor) { | |||
| auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor)); | |||
| Status rc = self(ms_tensor, &ms_tensor); | |||
| if (rc.IsError()) { | |||
| THROW_IF_ERROR([&rc]() { | |||
| RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); | |||
| }()); | |||
| } | |||
| THROW_IF_ERROR(self(ms_tensor, &ms_tensor)); | |||
| std::shared_ptr<dataset::Tensor> de_output_tensor; | |||
| dataset::Tensor::CreateFromMemory(dataset::TensorShape(ms_tensor.Shape()), | |||
| MSTypeToDEType(static_cast<TypeId>(ms_tensor.DataType())), | |||
| @@ -51,11 +46,7 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) { | |||
| auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor)); | |||
| ms_input_tensor_list.emplace_back(std::move(ms_tensor)); | |||
| } | |||
| Status rc = self(ms_input_tensor_list, &ms_output_tensor_list); | |||
| if (rc.IsError()) { | |||
| THROW_IF_ERROR( | |||
| [&rc]() { RETURN_STATUS_UNEXPECTED("Failed to execute transform op, " + rc.ToString()); }()); | |||
| } | |||
| THROW_IF_ERROR(self(ms_input_tensor_list, &ms_output_tensor_list)); | |||
| std::vector<std::shared_ptr<dataset::Tensor>> de_output_tensor_list; | |||
| for (auto &tensor : ms_output_tensor_list) { | |||
| std::shared_ptr<dataset::Tensor> de_output_tensor; | |||
| @@ -825,8 +825,8 @@ std::shared_ptr<TensorOp> PadOperation::Build() { | |||
| break; | |||
| case 2: | |||
| pad_left = padding_[0]; | |||
| pad_top = padding_[1]; | |||
| pad_right = padding_[0]; | |||
| pad_top = padding_[0]; | |||
| pad_right = padding_[1]; | |||
| pad_bottom = padding_[1]; | |||
| break; | |||
| default: | |||
| @@ -1096,8 +1096,8 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() { | |||
| break; | |||
| case 2: | |||
| pad_left = padding_[0]; | |||
| pad_top = padding_[1]; | |||
| pad_right = padding_[0]; | |||
| pad_top = padding_[0]; | |||
| pad_right = padding_[1]; | |||
| pad_bottom = padding_[1]; | |||
| break; | |||
| default: | |||
| @@ -1215,8 +1215,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() { | |||
| break; | |||
| case 2: | |||
| pad_left = padding_[0]; | |||
| pad_top = padding_[1]; | |||
| pad_right = padding_[0]; | |||
| pad_top = padding_[0]; | |||
| pad_right = padding_[1]; | |||
| pad_bottom = padding_[1]; | |||
| break; | |||
| default: | |||
| @@ -141,7 +141,7 @@ std::shared_ptr<JiebaTokenizerOperation> JiebaTokenizer(const std::string &hmm_p | |||
| const JiebaMode &mode = JiebaMode::kMix, | |||
| bool with_offsets = false); | |||
| /// \brief Lookup operator that looks up a word to an id. | |||
| /// \brief Look up a word in the input vocabulary table and return its id. | |||
| /// \param[in] vocab a Vocab object. | |||
| /// \param[in] unknown_token word to use for lookup if the word being looked up is out of Vocabulary (oov). | |||
| /// If unknown_token is oov, runtime error will be thrown. If unknown_token is {}, which means that not to | |||
| @@ -200,8 +200,8 @@ std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &me | |||
| /// \notes Pads the image according to padding parameters | |||
| /// \param[in] padding A vector representing the number of pixels to pad the image | |||
| /// If vector has one value, it pads all sides of the image with that value. | |||
| /// If vector has two values, it pads left and right with the first and | |||
| /// top and bottom with the second value. | |||
| /// If vector has two values, it pads left and top with the first and | |||
| /// right and bottom with the second value. | |||
| /// If vector has four values, it pads left, top, right, and bottom with | |||
| /// those values respectively. | |||
| /// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is | |||
| @@ -270,8 +270,12 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> | |||
| /// \param[in] size A vector representing the output size of the cropped image. | |||
| /// If size is a single value, a square crop of size (size, size) is returned. | |||
| /// If size has 2 values, it should be (height, width). | |||
| /// \param[in] padding A vector with the value of pixels to pad the image. If 4 values are provided, | |||
| /// it pads the left, top, right and bottom respectively. | |||
| /// \param[in] padding A vector representing the number of pixels to pad the image | |||
| /// If vector has one value, it pads all sides of the image with that value. | |||
| /// If vector has two values, it pads left and top with the first and | |||
| /// right and bottom with the second value. | |||
| /// If vector has four values, it pads left, top, right, and bottom with | |||
| /// those values respectively. | |||
| /// \param[in] pad_if_needed A boolean whether to pad the image if either side is smaller than | |||
| /// the given output size. | |||
| /// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is | |||
| @@ -304,8 +308,12 @@ std::shared_ptr<RandomCropDecodeResizeOperation> RandomCropDecodeResize( | |||
| /// \param[in] size A vector representing the output size of the cropped image. | |||
| /// If size is a single value, a square crop of size (size, size) is returned. | |||
| /// If size has 2 values, it should be (height, width). | |||
| /// \param[in] padding A vector with the value of pixels to pad the image. If 4 values are provided, | |||
| /// it pads the left, top, right and bottom respectively. | |||
| /// \param[in] padding A vector representing the number of pixels to pad the image | |||
| /// If vector has one value, it pads all sides of the image with that value. | |||
| /// If vector has two values, it pads left and top with the first and | |||
| /// right and bottom with the second value. | |||
| /// If vector has four values, it pads left, top, right, and bottom with | |||
| /// those values respectively. | |||
| /// \param[in] pad_if_needed A boolean whether to pad the image if either side is smaller than | |||
| /// the given output size. | |||
| /// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is | |||
| @@ -24,7 +24,7 @@ namespace dataset { | |||
| Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { | |||
| IO_CHECK_VECTOR(input, output); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Duplicate: only support one input."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Duplicate: only supports transform one column each time."); | |||
| std::shared_ptr<Tensor> out; | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input[0], &out)); | |||
| output->push_back(input[0]); | |||
| @@ -103,8 +103,12 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out | |||
| RETURN_STATUS_UNEXPECTED("Resize: load image failed."); | |||
| } | |||
| if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { | |||
| RETURN_STATUS_UNEXPECTED("Resize: input is not in shape of <H,W,C> or <H,W>"); | |||
| RETURN_STATUS_UNEXPECTED("Resize: input tensor is not in shape of <H,W,C> or <H,W>"); | |||
| } | |||
| if (input_cv->shape()[2] != 3 && input_cv->shape()[2] != 1) { | |||
| RETURN_STATUS_UNEXPECTED("Resize: channel of input tesnor is not in 1 or 3."); | |||
| } | |||
| cv::Mat in_image = input_cv->mat(); | |||
| // resize image too large or too small | |||
| if (output_height > in_image.rows * 1000 || output_width > in_image.cols * 1000) { | |||
| @@ -41,7 +41,7 @@ Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt | |||
| /// Get number of channels and image matrix | |||
| std::size_t num_of_channels = input_cv->shape()[2]; | |||
| if (num_of_channels != 1 && num_of_channels != 3) { | |||
| RETURN_STATUS_UNEXPECTED("Sharpness: image shape is not <H,W,C>."); | |||
| RETURN_STATUS_UNEXPECTED("Sharpness: image channel is not 1 or 3."); | |||
| } | |||
| /// creating a smoothing filter. 1, 1, 1, | |||
| @@ -579,7 +579,7 @@ class Dataset: | |||
| python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This | |||
| option could be beneficial if the Python operation is computational heavy (default=False). | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None). | |||
| @@ -2288,7 +2288,7 @@ class MapDataset(Dataset): | |||
| python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This | |||
| option could be beneficial if the Python operation is computational heavy (default=False). | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None) | |||
| Raises: | |||
| @@ -2843,7 +2843,7 @@ class ImageFolderDataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -2951,7 +2951,7 @@ class MnistDataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -2981,7 +2981,7 @@ class MnistDataset(MappableDataset): | |||
| class MindDataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads MindRecord files. | |||
| A source dataset for reading and parsing the MindRecord dataset. | |||
| Args: | |||
| dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents for | |||
| @@ -3505,7 +3505,7 @@ class GeneratorDataset(MappableDataset): | |||
| class TFRecordDataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses datasets stored on disk in TFData format. | |||
| A source dataset for reading and parsing datasets stored on disk in TFData format. | |||
| Args: | |||
| dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a | |||
| @@ -3537,7 +3537,7 @@ class TFRecordDataset(SourceDataset): | |||
| shard_equal_rows (bool, optional): Get equal rows for all shards(default=False). If shard_equal_rows | |||
| is false, number of rows of each shard may be not equal. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> import mindspore.common.dtype as mstype | |||
| @@ -3581,7 +3581,7 @@ class TFRecordDataset(SourceDataset): | |||
| class ManifestDataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads images from a manifest file. | |||
| A source dataset for reading images from a Manifest file. | |||
| The generated dataset has two columns ['image', 'label']. | |||
| The shape of the image column is [image_size] if decode flag is False, or [H,W,C] | |||
| @@ -3637,7 +3637,7 @@ class ManifestDataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -3690,7 +3690,7 @@ class ManifestDataset(MappableDataset): | |||
| class Cifar10Dataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads cifar10 data. | |||
| A source dataset for reading and parsing the Cifar10 dataset. | |||
| The generated dataset has two columns ['image', 'label']. | |||
| The type of the image tensor is uint8. The label is a scalar uint32 tensor. | |||
| @@ -3756,7 +3756,7 @@ class Cifar10Dataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -3793,7 +3793,7 @@ class Cifar10Dataset(MappableDataset): | |||
| class Cifar100Dataset(MappableDataset): | |||
| """ | |||
| A source dataset that reads cifar100 data. | |||
| A source dataset for reading and parsing the Cifar100 dataset. | |||
| The generated dataset has three columns ['image', 'coarse_label', 'fine_label']. | |||
| The type of the image tensor is uint8. The coarse and fine labels are each a scalar uint32 tensor. | |||
| @@ -3861,7 +3861,7 @@ class Cifar100Dataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -3906,7 +3906,7 @@ class RandomDataset(SourceDataset): | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, number set in the config). | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| shuffle (bool, optional): Whether or not to perform shuffle on the dataset | |||
| (default=None, expected order behavior shown in the table). | |||
| num_shards (int, optional): Number of shards that the dataset will be divided | |||
| @@ -4125,7 +4125,7 @@ class VOCDataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If xml of Annotations is an invalid format. | |||
| @@ -4280,7 +4280,7 @@ class CocoDataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| @@ -4391,7 +4391,7 @@ class CelebADataset(MappableDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train') | |||
| @@ -4465,7 +4465,7 @@ class CLUEDataset(SourceDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple text files | |||
| @@ -4592,7 +4592,7 @@ class CSVDataset(SourceDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| @@ -4642,7 +4642,7 @@ class TextFileDataset(SourceDataset): | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. | |||
| (default=None which means no cache is used). | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> # contains 1 or multiple text files | |||
| @@ -586,7 +586,7 @@ class SubsetSampler(BuiltinSampler): | |||
| for i, item in enumerate(indices): | |||
| if not isinstance(item, numbers.Number): | |||
| raise TypeError("type of weights element should be number, " | |||
| raise TypeError("type of indices element should be number, " | |||
| "but got w[{}]: {}, type: {}.".format(i, item, type(item))) | |||
| self.indices = indices | |||
| @@ -275,7 +275,7 @@ class JiebaTokenizer(TextTensorOperation): | |||
| class Lookup(TextTensorOperation): | |||
| """ | |||
| Lookup operator that looks up a word to an id. | |||
| Look up a word in the input vocabulary table and return its id. | |||
| Args: | |||
| vocab (Vocab): A vocabulary object. | |||
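To make the corrected Lookup description concrete, here is a minimal sketch; the vocabulary words, column name, and sample tokens are illustrative assumptions, not part of this patch:

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text

# Hypothetical vocabulary; ids follow list order: "<unk>" -> 0, "hello" -> 1, "world" -> 2.
vocab = text.Vocab.from_list(["<unk>", "hello", "world"])

def gen():
    for word in ["hello", "world", "goodbye"]:
        yield (np.array(word),)

data = ds.GeneratorDataset(gen, column_names=["text"])
# Out-of-vocabulary words ("goodbye" here) are mapped to the id of unknown_token.
data = data.map(operations=text.Lookup(vocab, unknown_token="<unk>"), input_columns=["text"])
for row in data.create_dict_iterator(output_numpy=True):
    print(row["text"])  # prints 1, 2, then 0 for the OOV word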
| @@ -39,13 +39,14 @@ class OneHot(cde.OneHotOp): | |||
| Tensor operation to apply one hot encoding. | |||
| Args: | |||
| num_classes (int): Number of classes of the label. | |||
| num_classes (int): Number of classes of objects in dataset. | |||
| It should be larger than the largest label number in the dataset. | |||
| Raises: | |||
| RuntimeError: feature size is bigger than num_classes. | |||
| Examples: | |||
| >>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9 | |||
| >>> onehot_op = c_transforms.OneHot(num_classes=10) | |||
| >>> mnist_dataset = mnist_dataset.map(operations=onehot_op, input_columns=["label"]) | |||
| """ | |||
| @@ -114,13 +115,13 @@ class _SliceOption(cde.SliceOption): | |||
| Internal class SliceOption to be used with SliceOperation | |||
| Args: | |||
| _SliceOption(Union[int, list(int), slice, None, Ellipses, bool, _SliceOption]): | |||
| _SliceOption(Union[int, list(int), slice, None, Ellipsis, bool, _SliceOption]): | |||
| 1. :py:obj:`int`: Slice this index only along the dimension. Negative index is supported. | |||
| 2. :py:obj:`list(int)`: Slice these indices along the dimension. Negative indices are supported. | |||
| 3. :py:obj:`slice`: Slice the generated indices from the slice object along the dimension. | |||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| 5. :py:obj:`Ellipses`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| 6. :py:obj:`boolean`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| """ | |||
| @@ -143,16 +144,16 @@ class Slice(cde.SliceOp): | |||
| (Currently only rank-1 tensors are supported). | |||
| Args: | |||
| *slices(Union[int, list(int), slice, None, Ellipses]): | |||
| *slices(Union[int, list(int), slice, None, Ellipsis]): | |||
| Maximum `n` number of arguments to slice a tensor of rank `n`. | |||
| One object in slices can be one of: | |||
| 1. :py:obj:`int`: Slice this index only along the first dimension. Negative index is supported. | |||
| 2. :py:obj:`list(int)`: Slice these indices along the first dimension. Negative indices are supported. | |||
| 3. :py:obj:`slice`: Slice the generated indices from the slice object along the first dimension. | |||
| Similar to `start:stop:step`. | |||
| Similar to start:stop:step. | |||
| 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| 5. :py:obj:`Ellipses`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| 5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing. | |||
| Examples: | |||
| >>> # Data before | |||
| @@ -232,7 +233,7 @@ class Mask(cde.MaskOp): | |||
| class PadEnd(cde.PadEndOp): | |||
| """ | |||
| Pad input tensor according to `pad_shape`, need to have same rank. | |||
| Pad input tensor according to pad_shape, need to have same rank. | |||
| Args: | |||
| pad_shape (list(int)): List of integers representing the shape needed. Dimensions that set to `None` will | |||
| @@ -295,7 +296,7 @@ class Concatenate(cde.ConcatenateOp): | |||
| class Duplicate(cde.DuplicateOp): | |||
| """ | |||
| Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list. | |||
| Duplicate the input tensor to output, only supports transforming one column at a time. | |||
| Examples: | |||
| >>> # Data before | |||
| @@ -405,7 +406,7 @@ class RandomApply(): | |||
| class RandomChoice(): | |||
| """ | |||
| Randomly selects one transform from a list of transforms to perform operation. | |||
| Randomly select one transform from a list of transforms to perform operation. | |||
| Args: | |||
| transforms (list): List of transformations to be chosen from to apply. | |||
| @@ -26,11 +26,13 @@ class OneHotOp: | |||
| Apply one hot encoding transformation to the input label, make label be more smoothing and continuous. | |||
| Args: | |||
| num_classes (int): Number of classes of objects in dataset. Value must be larger than 0. | |||
| num_classes (int): Number of classes of objects in dataset. | |||
| It should be larger than the largest label number in the dataset. | |||
| smoothing_rate (float, optional): Adjustable hyperparameter for label smoothing level. | |||
| (Default=0.0 means no smoothing is applied.) | |||
| Examples: | |||
| >>> # Assume that dataset has 10 classes, thus the label ranges from 0 to 9 | |||
| >>> transforms_list = [py_transforms.OneHotOp(num_classes=10, smoothing_rate=0.1)] | |||
| >>> transform = py_transforms.Compose(transforms_list) | |||
| >>> mnist_dataset = mnist_dataset(input_columns=["label"], operations=transform) | |||
| @@ -83,14 +85,14 @@ class Compose: | |||
| >>> | |||
| >>> # Compose is also be invoked implicitly, by just passing in a list of ops | |||
| >>> # the above example then becomes: | |||
| >>> transform_list = [py_vision.Decode(), | |||
| ... py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.ToTensor(), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms_list = [py_vision.Decode(), | |||
| ... py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.ToTensor(), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> | |||
| >>> # apply the transform to the dataset through dataset.map() | |||
| >>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=transform_list, input_columns=["image"]) | |||
| >>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=transforms_list, input_columns=["image"]) | |||
| >>> | |||
| >>> # Certain C++ and Python ops can be combined, but not all of them | |||
| >>> # An example of combined operations | |||
| @@ -163,9 +165,9 @@ class RandomApply: | |||
| Examples: | |||
| >>> from mindspore.dataset.transforms.py_transforms import Compose | |||
| >>> transform_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms = Compose([py_vision.Decode(), | |||
| ... py_transforms.RandomApply(transforms_list, prob=0.6), | |||
| ... py_vision.ToTensor()]) | |||
| @@ -199,11 +201,11 @@ class RandomChoice: | |||
| Examples: | |||
| >>> from mindspore.dataset.transforms.py_transforms import Compose, RandomChoice | |||
| >>> transform_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms = Compose([py_vision.Decode(), | |||
| ... py_transforms.RandomChoice(transform_list), | |||
| ... py_transforms.RandomChoice(transforms_list), | |||
| ... py_vision.ToTensor()]) | |||
| >>> image_folder_dataset = image_folder_dataset.map(operations=transforms, input_columns=["image"]) | |||
| """ | |||
| @@ -234,9 +236,9 @@ class RandomOrder: | |||
| Examples: | |||
| >>> from mindspore.dataset.transforms.py_transforms import Compose | |||
| >>> transform_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms_list = [py_vision.RandomHorizontalFlip(0.5), | |||
| ... py_vision.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)), | |||
| ... py_vision.RandomErasing()] | |||
| >>> transforms = Compose([py_vision.Decode(), | |||
| ... py_transforms.RandomOrder(transforms_list), | |||
| ... py_vision.ToTensor()]) | |||
| @@ -108,8 +108,8 @@ def parse_padding(padding): | |||
| if isinstance(padding, numbers.Number): | |||
| padding = [padding] * 4 | |||
| if len(padding) == 2: | |||
| left = right = padding[0] | |||
| top = bottom = padding[1] | |||
| left = top = padding[0] | |||
| right = bottom = padding[1] | |||
| padding = (left, top, right, bottom,) | |||
| if isinstance(padding, list): | |||
| padding = tuple(padding) | |||
| @@ -438,8 +438,8 @@ class Pad(ImageTensorOperation): | |||
| Args: | |||
| padding (Union[int, sequence]): The number of pixels to pad the image. | |||
| If a single number is provided, it pads all borders with this value. | |||
| If a tuple or list of 2 values are provided, it pads left and right | |||
| with the first value and top and bottom with the second value. | |||
| If a tuple or list of 2 values are provided, it pads the (left and top) | |||
| with the first value and (right and bottom) with the second value. | |||
| If 4 values are provided as a list or tuple, | |||
| it pads the left, top, right and bottom respectively. | |||
| fill_value (Union[int, tuple], optional): The pixel intensity of the borders, only valid for | |||
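As a usage sketch of the corrected two-value padding order (mirroring the new test_pad_op2 added later in this patch; the dataset path and the 90x90 resize are assumptions for illustration):

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

# Hypothetical image folder; any image dataset with an "image" column behaves the same.
data = ds.ImageFolderDataset("/path/to/image_folder")
transforms = [c_vision.Decode(),
              c_vision.Resize([90, 90]),
              # Two values now pad left/top with the first (100) and right/bottom with the second (9),
              # so each dimension becomes 90 + 100 + 9 = 199.
              c_vision.Pad((100, 9))]
data = data.map(operations=transforms, input_columns=["image"])
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    assert row["image"].shape[0] == 199 and row["image"].shape[1] == 199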
| @@ -674,8 +674,8 @@ class RandomCrop(ImageTensorOperation): | |||
| padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None). | |||
| If padding is not None, pad image firstly with padding values. | |||
| If a single number is provided, pad all borders with this value. | |||
| If a tuple or list of 2 values are provided, it pads left and right | |||
| with the first value and top and bottom with the second value. | |||
| If a tuple or list of 2 values are provided, pad the (left and top) | |||
| with the first value and (right and bottom) with the second value. | |||
| If 4 values are provided as a list or tuple, | |||
| pad the left, top, right and bottom respectively. | |||
| pad_if_needed (bool, optional): Pad the image if either side is smaller than | |||
| @@ -790,8 +790,8 @@ class RandomCropWithBBox(ImageTensorOperation): | |||
| padding (Union[int, sequence], optional): The number of pixels to pad the image (default=None). | |||
| If padding is not None, first pad image with padding values. | |||
| If a single number is provided, pad all borders with this value. | |||
| If a tuple or list of 2 values are provided, it pads left and right | |||
| with the first value and top and bottom with the second value. | |||
| If a tuple or list of 2 values are provided, pad the (left and top) | |||
| with the first value and (right and bottom) with the second value. | |||
| If 4 values are provided as a list or tuple, pad the left, top, right and bottom respectively. | |||
| pad_if_needed (bool, optional): Pad the image if either side is smaller than | |||
| the given output size (default=False). | |||
| @@ -845,7 +845,7 @@ class RandomCropWithBBox(ImageTensorOperation): | |||
| class RandomHorizontalFlip(ImageTensorOperation): | |||
| """ | |||
| Flip the input image horizontally, randomly with a given probability. | |||
| Randomly flip the input image horizontally with a given probability. | |||
| Args: | |||
| prob (float, optional): Probability of the image being flipped (default=0.5). | |||
| @@ -1202,12 +1202,12 @@ class RandomSharpness(ImageTensorOperation): | |||
| class RandomSolarize(ImageTensorOperation): | |||
| """ | |||
| Invert all pixel values above a threshold. | |||
| Invert all pixel values within the given range. | |||
| Args: | |||
| threshold (tuple, optional): Range of random solarize threshold. Threshold values should always be | |||
| in the range (0, 255), include at least one integer value in the given range and | |||
| be in (min, max) format. If min=max, then it is a single fixed magnitude operation (default=(0, 255)). | |||
| in the range (0, 255), include at least one integer value in the given range and be in | |||
| (min, max) format. If min=max, then invert all pixel values above min(max) (default=(0, 255)). | |||
| Examples: | |||
| >>> transforms_list = [c_vision.Decode(), c_vision.RandomSolarize(threshold=(10,100))] | |||
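A brief hedged sketch of the min=max case mentioned above; the value 128 is an arbitrary illustration:

import mindspore.dataset.vision.c_transforms as c_vision

# With min == max the threshold is fixed rather than random:
# every pixel value above 128 is inverted to (255 - pixel).
transforms_list = [c_vision.Decode(), c_vision.RandomSolarize(threshold=(128, 128))]
# image_folder_dataset is assumed to already exist with an "image" column:
# image_folder_dataset = image_folder_dataset.map(operations=transforms_list, input_columns=["image"])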
| @@ -1225,7 +1225,7 @@ class RandomSolarize(ImageTensorOperation): | |||
| class RandomVerticalFlip(ImageTensorOperation): | |||
| """ | |||
| Flip the input image vertically, randomly with a given probability. | |||
| Randomly flip the input image vertically with a given probability. | |||
| Args: | |||
| prob (float, optional): Probability of the image being flipped (default=0.5). | |||
| @@ -1357,7 +1357,8 @@ class RandomColor: | |||
| class RandomSharpness: | |||
| """ | |||
| Adjust the sharpness of the input PIL image by a random degree. | |||
| Adjust the sharpness of the input PIL image by a fixed or random degree. Degree of 0.0 gives a blurred image, | |||
| degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image. | |||
| Args: | |||
| degrees (sequence): Range of random sharpness adjustment degrees. | |||
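A minimal usage sketch of the degree semantics described above; the (0.2, 1.8) range and the dataset variable are illustrative assumptions:

import mindspore.dataset.vision.py_transforms as py_vision
from mindspore.dataset.transforms.py_transforms import Compose

# Degrees are drawn from the given range: 0.0 blurs, 1.0 keeps the original, 2.0 sharpens.
transform = Compose([py_vision.Decode(),
                     py_vision.RandomSharpness(degrees=(0.2, 1.8)),
                     py_vision.ToTensor()])
# image_folder_dataset = image_folder_dataset.map(operations=transform, input_columns=["image"])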
| @@ -499,7 +499,8 @@ def check_cutout(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [length, num_patches], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(length, (int,), "length") | |||
| type_check(num_patches, (int,), "num_patches") | |||
| check_value(length, (1, FLOAT_MAX_INTEGER)) | |||
| check_value(num_patches, (1, FLOAT_MAX_INTEGER)) | |||
| @@ -613,7 +614,10 @@ def check_uniform_augment_cpp(method): | |||
| parsed_transforms.append(op.parse()) | |||
| else: | |||
| parsed_transforms.append(op) | |||
| type_check_list(parsed_transforms, (TensorOp, TensorOperation), "transforms") | |||
| type_check(parsed_transforms, (list, tuple,), "transforms") | |||
| for index, arg in enumerate(parsed_transforms): | |||
| if not isinstance(arg, (TensorOp, TensorOperation)): | |||
| raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg))) | |||
| return method(self, *args, **kwargs) | |||
| @@ -72,6 +72,26 @@ def test_pad_op(): | |||
| assert mse < 0.01 | |||
| def test_pad_op2(): | |||
| """ | |||
| Test Pad op2 | |||
| """ | |||
| logger.info("test padding parameter with size 2") | |||
| data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) | |||
| decode_op = c_vision.Decode() | |||
| resize_op = c_vision.Resize([90, 90]) | |||
| pad_op = c_vision.Pad((100, 9,)) | |||
| ctrans = [decode_op, resize_op, pad_op] | |||
| data1 = data1.map(operations=ctrans, input_columns=["image"]) | |||
| for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| logger.info(data["image"].shape) | |||
| # It pads left, top with 100 and right, bottom with 9, | |||
| # so the final size of image is 90 + 100 + 9 = 199 | |||
| assert data["image"].shape[0] == 199 | |||
| assert data["image"].shape[1] == 199 | |||
| def test_pad_grayscale(): | |||
| """ | |||
| @@ -145,6 +145,28 @@ def test_slice_multiple_rows(): | |||
| np.testing.assert_array_equal(exp_d, d['col']) | |||
| def test_slice_none_and_ellipsis(): | |||
| """ | |||
| Test passing None and Ellipsis to Slice | |||
| """ | |||
| dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]] | |||
| exp_dataset = [[1], [3, 4, 5], [1, 2], [1, 2, 3, 4, 5, 6, 7]] | |||
| def gen(): | |||
| for row in dataset: | |||
| yield (np.array(row),) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| data = data.map(operations=ops.Slice(None)) | |||
| for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset): | |||
| np.testing.assert_array_equal(exp_d, d['col']) | |||
| data = ds.GeneratorDataset(gen, column_names=["col"]) | |||
| data = data.map(operations=ops.Slice(Ellipsis)) | |||
| for (d, exp_d) in zip(data.create_dict_iterator(output_numpy=True), exp_dataset): | |||
| np.testing.assert_array_equal(exp_d, d['col']) | |||
| def test_slice_obj_neg(): | |||
| slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1), [5, 4, 3, 2]) | |||
| slice_compare([1, 2, 3, 4, 5], slice(-1), [1, 2, 3, 4]) | |||
| @@ -186,10 +186,7 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||
| C.UniformAugment(transforms=transforms_ua, num_ops=num_ops) | |||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||
| assert "Argument transforms[5] with value" \ | |||
| " <mindspore.dataset.vision.py_transforms.Invert" in str(e.value) | |||
| assert "is not of type (<class 'mindspore._c_dataengine.TensorOp'>,"\ | |||
| " <class 'mindspore._c_dataengine.TensorOperation'>)" in str(e.value) | |||
| assert "Type of Transforms[5] must be c_transform" in str(e.value) | |||
| def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | |||