
!11853 Add call for decoupled image and text ops

From: @alexyuyue
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot committed 4 years ago
commit 320ea51308
14 changed files with 1207 additions and 1096 deletions
  1.  +23  -2    mindspore/ccsrc/minddata/dataset/api/execute.cc
  2.  +17  -5    mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc
  3.  +10  -3    mindspore/ccsrc/minddata/dataset/include/execute.h
  4.  +396 -368  mindspore/dataset/text/transforms.py
  5.  +10  -2    mindspore/dataset/transforms/c_transforms.py
  6.  +11  -9    mindspore/dataset/transforms/validators.py
  7.  +631 -696  mindspore/dataset/vision/c_transforms.py
  8.  +16  -1    tests/ut/python/dataset/test_HWC2CHW.py
  9.  +5   -5    tests/ut/python/dataset/test_compose.py
  10. +17  -1    tests/ut/python/dataset/test_invert.py
  11. +18  -1    tests/ut/python/dataset/test_random_crop_and_resize.py
  12. +20  -1    tests/ut/python/dataset/test_text_jieba_tokenizer.py
  13. +20  -1    tests/ut/python/dataset/test_uniform_augment.py
  14. +13  -1    tests/ut/python/dataset/test_vocab.py
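
In short, this change makes the C++ image and text ops directly callable in eager mode, outside any dataset pipeline. A minimal sketch of the new usage, based on the test_HWC2CHW_callable unit test added in this merge request (the image path and expected shape come from the bundled test data, so treat them as illustrative):

import numpy as np
import mindspore.dataset.vision.c_transforms as c_vision

# Run the decoupled ops one by one on a raw JPEG, no dataset pipeline involved.
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
img = c_vision.Decode()(img)     # decode to an HWC uint8 image
img = c_vision.HWC2CHW()(img)    # transpose to CHW
print(img.shape)                 # (3, 2268, 4032) for the bundled test image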

+23 -2  mindspore/ccsrc/minddata/dataset/api/execute.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,10 +14,11 @@
* limitations under the License.
*/

#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/core/tensor_row.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/include/de_tensor.h"
#endif
#include "minddata/dataset/include/execute.h"
#include "minddata/dataset/include/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#ifndef ENABLE_ANDROID
@@ -84,5 +85,25 @@ std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Te
return de_output;
}

Status Execute::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
std::vector<std::shared_ptr<Tensor>> *output_tensor_list) {
CHECK_FAIL_RETURN_UNEXPECTED(op_ != nullptr, "Input TensorOperation is not valid");
CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");

TensorRow input, output;
std::copy(input_tensor_list.begin(), input_tensor_list.end(), std::back_inserter(input));
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "Input Tensor is not valid");

std::shared_ptr<TensorOp> transform = op_->Build();
Status rc = transform->Compute(input, &output);
if (rc.IsError()) {
// execution failed
RETURN_STATUS_UNEXPECTED("Operation execution failed : " + rc.ToString());
}

std::copy(output.begin(), output.end(), std::back_inserter(*output_tensor_list));
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+17 -5  mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/execute_binding.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,14 +28,26 @@ PYBIND_REGISTER(Execute, 0, ([](const py::module *m) {
auto execute = std::make_shared<Execute>(toTensorOperation(operation));
return execute;
}))
.def("__call__", [](Execute &self, std::shared_ptr<Tensor> in) {
std::shared_ptr<Tensor> out = self(in);
if (out == nullptr) {
.def("__call__",
[](Execute &self, std::shared_ptr<Tensor> in) {
std::shared_ptr<Tensor> out = self(in);
if (out == nullptr) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED(
"Failed to execute op in eager mode, please check ERROR log above.");
}());
}
return out;
})
.def("__call__", [](Execute &self, const std::vector<std::shared_ptr<Tensor>> &input_tensor_list) {
std::vector<std::shared_ptr<Tensor>> output_tensor_list;
THROW_IF_ERROR(self(input_tensor_list, &output_tensor_list));
if (output_tensor_list.empty()) {
THROW_IF_ERROR([]() {
RETURN_STATUS_UNEXPECTED("Failed to execute op in eager mode, please check ERROR log above.");
}());
}
return out;
return output_tensor_list;
});
}));
} // namespace dataset
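
The list-in/list-out __call__ registered here is what the Python eager path goes through. A rough sketch of how it is exercised, mirroring the TextTensorOperation.__call__ added in transforms.py further down; the `mindspore._c_dataengine` module name behind the `cde` alias is an assumption about the usual binding import, not part of this diff:

import numpy as np
import mindspore._c_dataengine as cde   # assumed module behind the `cde` alias used in this diff
import mindspore.dataset.text as text

# Wrap the input into a dataset Tensor, build Execute from the parsed op,
# then call it with a list -- this hits the vector<Tensor> __call__ registered above.
tokenizer = text.UnicodeCharTokenizer()
tensor_list = [cde.Tensor(np.asarray("hello"))]
callable_op = cde.Execute(tokenizer.parse())
output_list = callable_op(tensor_list)
print(output_list[0].as_array())        # the input split into single characters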


+10 -3  mindspore/ccsrc/minddata/dataset/include/execute.h

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -43,16 +43,23 @@ class Execute {

#ifdef ENABLE_ANDROID
/// \brief callable function to execute the TensorOperation in eager mode
/// \param[inout] input - the tensor to be transformed
/// \param[in] input - the tensor to be transformed
/// \return - the output tensor, nullptr if Compute fails
std::shared_ptr<tensor::MSTensor> operator()(std::shared_ptr<tensor::MSTensor> input);
#endif

/// \brief callable function to execute the TensorOperation in eager mode
/// \param[inout] input - the tensor to be transformed
/// \param[in] input - the tensor to be transformed
/// \return - the output tensor, nullptr if Compute fails
std::shared_ptr<dataset::Tensor> operator()(std::shared_ptr<dataset::Tensor> input);

/// \brief callable function to execute the TensorOperation in eager mode
/// \param[in] input_tensor_list - the tensor to be transformed
/// \param[out] out - the result tensor after transform
/// \return - Status
Status operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensor_list,
std::vector<std::shared_ptr<Tensor>> *out);

private:
std::shared_ptr<TensorOperation> op_;
};


+396 -368  mindspore/dataset/text/transforms.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -59,112 +59,37 @@ from .validators import check_lookup, check_jieba_add_dict, \
check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none
from ..transforms.c_transforms import TensorOperation

class TextTensorOperation:
def parse(self):
raise NotImplementedError("TextTensorOperation has to implement parse method.")

class Lookup(TextTensorOperation):
"""
Lookup operator that looks up a word to an id.

Args:
vocab (Vocab): A vocabulary object.
unknown_token (str, optional): Word used for lookup if the word being looked up is out-of-vocabulary (OOV).
If unknown_token is OOV, a runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mstype.int32)

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Load vocabulary from list
>>> vocab = text.Vocab.from_list(['深', '圳', '欢', '迎', '您'])
>>> # Use Lookup operator to map tokens to ids
>>> lookup = text.Lookup(vocab)
>>> data1 = data1.map(operations=[lookup])
"""

@check_lookup
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
self.vocab = vocab
self.unknown_token = replace_none(unknown_token, '')
self.data_type = data_type

def parse(self):
return cde.LookupOperation(self.vocab, self.unknown_token, str(mstype_to_detype(self.data_type)))


class SlidingWindow(TextTensorOperation):
"""
TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis
is a slice of data starting at the corresponding position, with a specified width.

Args:
width (int): The width of the window. It must be an integer and greater than zero.
axis (int, optional): The axis along which the sliding window is computed (default=0).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 |
>>> # +-------------+
>>> # | [1,2,3,4,5] |
>>> # +-------------+
>>> data1 = data1.map(operations=text.SlidingWindow(3, 0))
>>> # Data after
>>> # | col1 |
>>> # +-------------+
>>> # | [[1,2,3], |
>>> # | [2,3,4], |
>>> # | [3,4,5]] |
>>> # +--------------+
"""

@check_slidingwindow
def __init__(self, width, axis=0):
self.width = width
self.axis = axis

def parse(self):
return cde.SlidingWindowOperation(self.width, self.axis)


class Ngram(TextTensorOperation):
class TextTensorOperation(TensorOperation):
"""
TensorOp to generate n-gram from a 1-D string Tensor.

Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works.

Args:
n (list[int]): n in n-gram, n >= 1. n is a list of positive integers. For example, if n=[4,3], then the result
would be a 4-gram followed by a 3-gram in the same tensor. If the number of words is not enough to make up
for a n-gram, an empty string will be returned. For example, 3 grams on ["mindspore","best"] will result in
an empty string produced.
left_pad (tuple, optional): ("pad_token", pad_width). Padding performed on left side of the sequence. pad_width
will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__" (default=None).
right_pad (tuple, optional): ("pad_token", pad_width). Padding performed on right side of the sequence.
pad_width will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--"
(default=None).
separator (str, optional): symbol used to join strings together. For example. if 2-gram is
["mindspore", "amazing"] with separator="-", the result would be ["mindspore-amazing"]
(default=None, which means whitespace is used).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> data1 = data1.map(operations=text.Ngram(3, separator=" "))
Base class of Text Tensor Ops
"""

@check_ngram
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
self.ngrams = n
self.left_pad = left_pad
self.right_pad = right_pad
self.separator = separator
def __call__(self, input_tensor):
if not isinstance(input_tensor, list):
input_list = [input_tensor]
else:
input_list = input_tensor
tensor_list = []
for tensor in input_list:
if not isinstance(tensor, str):
raise TypeError("Input should be string or list of strings, got {}.".format(type(tensor)))
tensor_list.append(cde.Tensor(np.asarray(tensor)))
callable_op = cde.Execute(self.parse())
output_list = callable_op(tensor_list)
for i, element in enumerate(output_list):
arr = element.as_array()
if arr.dtype.char == 'S':
output_list[i] = to_str(arr)
else:
output_list[i] = arr
if not isinstance(input_tensor, list) and len(output_list) == 1:
output_list = output_list[0]
return output_list

def parse(self):
return cde.NgramOperation(self.ngrams, self.left_pad, self.right_pad, self.separator)
raise NotImplementedError("TextTensorOperation has to implement parse() method.")


DE_C_INTER_JIEBA_MODE = {
@@ -174,6 +99,18 @@ DE_C_INTER_JIEBA_MODE = {
}


DE_C_INTER_SENTENCEPIECE_LOADTYPE = {
SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE,
SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL
}


DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString,
SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT
}


class JiebaTokenizer(TextTensorOperation):
"""
Tokenize Chinese string into words based on dictionary.
@@ -335,85 +272,71 @@ class JiebaTokenizer(TextTensorOperation):
" jieba mode file {} is not exist.".format(model_path))


class UnicodeCharTokenizer(TextTensorOperation):
class Lookup(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string to Unicode characters.
Lookup operator that looks up a word to an id.

Args:
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
vocab (Vocab): A vocabulary object.
unknown_token (str, optional): Word used for lookup if the word being looked up is out-of-vocabulary (OOV).
If unknown_token is OOV, a runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mstype.int32)

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.UnicodeCharTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeCharTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
>>> # Load vocabulary from list
>>> vocab = text.Vocab.from_list(['深', '圳', '欢', '迎', '您'])
>>> # Use Lookup operator to map tokens to ids
>>> lookup = text.Lookup(vocab)
>>> data1 = data1.map(operations=[lookup])
"""

@check_with_offsets
def __init__(self, with_offsets=False):
self.with_offsets = with_offsets
@check_lookup
def __init__(self, vocab, unknown_token=None, data_type=mstype.int32):
self.vocab = vocab
self.unknown_token = replace_none(unknown_token, '')
self.data_type = data_type

def parse(self):
return cde.UnicodeCharTokenizerOperation(self.with_offsets)
return cde.LookupOperation(self.vocab, self.unknown_token, str(mstype_to_detype(self.data_type)))


# TODO(alexyuyue): Need to decouple WordpieceTokenizerOp to WordpieceTokenizerOperation after it's supported in C++
class WordpieceTokenizer(cde.WordpieceTokenizerOp):
class Ngram(TextTensorOperation):
"""
Tokenize scalar token or 1-D tokens to 1-D subword tokens.
TensorOp to generate n-gram from a 1-D string Tensor.

Refer to https://en.wikipedia.org/wiki/N-gram#Examples for an overview of what n-gram is and how it works.

Args:
vocab (Vocab): A vocabulary object.
suffix_indicator (str, optional): Used to show that the subword is the last part of a word (default='##').
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split (default=100).
unknown_token (str, optional): When a token cannot be found: if 'unknown_token' is empty string,
return the token directly, else return 'unknown_token' (default='[UNK]').
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
n (list[int]): n in n-gram, n >= 1. n is a list of positive integers. For example, if n=[4,3], then the result
would be a 4-gram followed by a 3-gram in the same tensor. If the number of words is not enough to make up
for a n-gram, an empty string will be returned. For example, 3 grams on ["mindspore","best"] will result in
an empty string produced.
left_pad (tuple, optional): ("pad_token", pad_width). Padding performed on left side of the sequence. pad_width
will be capped at n-1. left_pad=("_",2) would pad left side of the sequence with "__" (default=None).
right_pad (tuple, optional): ("pad_token", pad_width). Padding performed on right side of the sequence.
pad_width will be capped at n-1. right_pad=("-":2) would pad right side of the sequence with "--"
(default=None).
separator (str, optional): symbol used to join strings together. For example. if 2-gram is
["mindspore", "amazing"] with separator="-", the result would be ["mindspore-amazing"]
(default=None, which means whitespace is used).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token='[UNK]',
>>> max_bytes_per_token=100, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token='[UNK]',
>>> max_bytes_per_token=100, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op,
>>> input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
>>> data1 = data1.map(operations=text.Ngram(3, separator=" "))
"""

@check_wordpiece_tokenizer
def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100,
unknown_token='[UNK]', with_offsets=False):
self.vocab = vocab
self.suffix_indicator = suffix_indicator
self.max_bytes_per_token = max_bytes_per_token
self.unknown_token = unknown_token
self.with_offsets = with_offsets
super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token,
self.unknown_token, self.with_offsets)


DE_C_INTER_SENTENCEPIECE_LOADTYPE = {
SPieceTokenizerLoadType.FILE: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KFILE,
SPieceTokenizerLoadType.MODEL: cde.SPieceTokenizerLoadType.DE_SPIECE_TOKENIZER_LOAD_KMODEL
}
@check_ngram
def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "):
self.ngrams = n
self.left_pad = left_pad
self.right_pad = right_pad
self.separator = separator

DE_C_INTER_SENTENCEPIECE_OUTTYPE = {
SPieceTokenizerOutType.STRING: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KString,
SPieceTokenizerOutType.INT: cde.SPieceTokenizerOutType.DE_SPIECE_TOKENIZER_OUTTYPE_KINT
}
def parse(self):
return cde.NgramOperation(self.ngrams, self.left_pad, self.right_pad, self.separator)


class SentencePieceTokenizer(TextTensorOperation):
@@ -441,75 +364,336 @@ class SentencePieceTokenizer(TextTensorOperation):
return cde.SentencePieceTokenizerOperation(self.mode, DE_C_INTER_SENTENCEPIECE_OUTTYPE[self.out_type])


if platform.system().lower() != 'windows':
class WhitespaceTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on ICU4C defined whitespaces, such as: ' ', '\\\\t', '\\\\r', '\\\\n'.
class SlidingWindow(TextTensorOperation):
"""
TensorOp to construct a tensor from data (only 1-D for now), where each element in the dimension axis
is a slice of data starting at the corresponding position, with a specified width.

Note:
WhitespaceTokenizer is not supported on Windows platform yet.
Args:
width (int): The width of the window. It must be an integer and greater than zero.
axis (int, optional): The axis along which the sliding window is computed (default=0).

Args:
with_offsets (bool, optional): If or not output offsets of tokens (default=False).
Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 |
>>> # +-------------+
>>> # | [1,2,3,4,5] |
>>> # +-------------+
>>> data1 = data1.map(operations=text.SlidingWindow(3, 0))
>>> # Data after
>>> # | col1 |
>>> # +-------------+
>>> # | [[1,2,3], |
>>> # | [2,3,4], |
>>> # | [3,4,5]] |
>>> # +--------------+
"""

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.WhitespaceTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.WhitespaceTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""
@check_slidingwindow
def __init__(self, width, axis=0):
self.width = width
self.axis = axis

@check_with_offsets
def __init__(self, with_offsets=False):
def parse(self):
return cde.SlidingWindowOperation(self.width, self.axis)


class ToNumber(TextTensorOperation):
"""
Tensor operation to convert every element of a string tensor to a number.

Strings are casted according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an
unsigned integer type.

Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be
a numeric type.

Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being casted.

Examples:
>>> import mindspore.dataset.text as text
>>> import mindspore.common.dtype as mstype
>>>
>>> to_number_op = text.ToNumber(mstype.int8)
>>> data1 = data1.map(operations=to_number_op)
"""

@check_to_number
def __init__(self, data_type):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)

def parse(self):
return cde.ToNumberOperation(self.data_type)


class TruncateSequencePair(TextTensorOperation):
"""
Truncate a pair of rank-1 tensors such that the total length is less than max_length.

This operation takes two input tensors and returns two output Tensors.

Args:
max_length (int): Maximum length required.

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 | col2 |
>>> # +---------+---------|
>>> # | [1,2,3] | [4,5] |
>>> # +---------+---------+
>>> data1 = data1.map(operations=text.TruncateSequencePair(4))
>>> # Data after
>>> # | col1 | col2 |
>>> # +---------+---------+
>>> # | [1,2] | [4,5] |
>>> # +---------+---------+
"""

@check_pair_truncate
def __init__(self, max_length):
self.max_length = max_length

def parse(self):
return cde.TruncateSequencePairOperation(self.max_length)


class UnicodeCharTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string to Unicode characters.

Args:
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.UnicodeCharTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeCharTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_with_offsets
def __init__(self, with_offsets=False):
self.with_offsets = with_offsets

def parse(self):
return cde.UnicodeCharTokenizerOperation(self.with_offsets)


# TODO(alexyuyue): Need to decouple WordpieceTokenizerOp to WordpieceTokenizerOperation after it's supported in C++
class WordpieceTokenizer(cde.WordpieceTokenizerOp):
"""
Tokenize scalar token or 1-D tokens to 1-D subword tokens.

Args:
vocab (Vocab): A vocabulary object.
suffix_indicator (str, optional): Used to show that the subword is the last part of a word (default='##').
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split (default=100).
unknown_token (str, optional): When a token cannot be found: if 'unknown_token' is empty string,
return the token directly, else return 'unknown_token' (default='[UNK]').
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token='[UNK]',
>>> max_bytes_per_token=100, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token='[UNK]',
>>> max_bytes_per_token=100, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op,
>>> input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_wordpiece_tokenizer
def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100,
unknown_token='[UNK]', with_offsets=False):
self.vocab = vocab
self.suffix_indicator = suffix_indicator
self.max_bytes_per_token = max_bytes_per_token
self.unknown_token = unknown_token
self.with_offsets = with_offsets
super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token,
self.unknown_token, self.with_offsets)


class PythonTokenizer:
"""
Callable class to be used for user-defined string tokenizer.

Args:
tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> def my_tokenizer(line):
>>> return line.split()
>>> data1 = data1.map(operations=text.PythonTokenizer(my_tokenizer))
"""

@check_python_tokenizer
def __init__(self, tokenizer):
self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')

def __call__(self, in_array):
in_array = to_str(in_array)
tokens = self.tokenizer(in_array)
return tokens

if platform.system().lower() != 'windows':
DE_C_INTER_NORMALIZE_FORM = {
NormalizeForm.NONE: cde.NormalizeForm.DE_NORMALIZE_NONE,
NormalizeForm.NFC: cde.NormalizeForm.DE_NORMALIZE_NFC,
NormalizeForm.NFKC: cde.NormalizeForm.DE_NORMALIZE_NFKC,
NormalizeForm.NFD: cde.NormalizeForm.DE_NORMALIZE_NFD,
NormalizeForm.NFKD: cde.NormalizeForm.DE_NORMALIZE_NFKD
}


class BasicTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string by specific rules.

Note:
BasicTokenizer is not supported on Windows platform yet.

Args:
lower_case (bool, optional): If True, apply CaseFold, NormalizeUTF8(NFD mode), RegexReplace operation
on input text to fold the text to lower case and strip accents characters. If False, only apply
NormalizeUTF8('normalization_form' mode) operation on input text (default=False).
keep_whitespace (bool, optional): If True, the whitespace will be kept in out tokens (default=False).
normalization_form (NormalizeForm, optional): Used to specify a specific normalize mode. This is
only effective when 'lower_case' is False. See NormalizeUTF8 for details (default=NormalizeForm.NONE).
preserve_unused_token (bool, optional): If True, do not split special tokens like
'[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]' (default=True).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.BasicTokenizer(lower_case=False,
>>> keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE,
>>> preserve_unused_token=True,
>>> with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.BasicTokenizer(lower_case=False,
>>> keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE,
>>> preserve_unused_token=True,
>>> with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_basic_tokenizer
def __init__(self, lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
preserve_unused_token=True, with_offsets=False):
if not isinstance(normalization_form, NormalizeForm):
raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

self.lower_case = lower_case
self.keep_whitespace = keep_whitespace
self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form]
self.preserve_unused_token = preserve_unused_token
self.with_offsets = with_offsets

def parse(self):
return cde.WhitespaceTokenizerOperation(self.with_offsets)
return cde.BasicTokenizerOperation(self.lower_case, self.keep_whitespace, self.normalization_form,
self.preserve_unused_token, self.with_offsets)


class UnicodeScriptTokenizer(TextTensorOperation):
class BertTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.
Tokenizer used for Bert text process.

Note:
UnicodeScriptTokenizer is not supported on Windows platform yet.
BertTokenizer is not supported on Windows platform yet.

Args:
keep_whitespace (bool, optional): If or not emit whitespace tokens (default=False).
vocab (Vocab): A vocabulary object.
suffix_indicator (str, optional): Used to show that the subword is the last part of a word (default='##').
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split (default=100).
unknown_token (str, optional): When a token cannot be found: if 'unknown_token' is empty string,
return the token directly, else return 'unknown_token'(default='[UNK]').
lower_case (bool, optional): If True, apply CaseFold, NormalizeUTF8(NFD mode), RegexReplace operation
on input text to fold the text to lower case and strip accented characters. If False, only apply
NormalizeUTF8('normalization_form' mode) operation on input text (default=False).
keep_whitespace (bool, optional): If True, the whitespace will be kept in out tokens (default=False).
normalization_form (NormalizeForm, optional): Used to specify a specific normalize mode,
only effective when 'lower_case' is False. See NormalizeUTF8 for details (default='NONE').
preserve_unused_token (bool, optional): If True, do not split special tokens like
'[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]' (default=True).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=False)
>>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100,
>>> unknown_token='[UNK]', lower_case=False, keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True,
>>> with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=True)
>>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100,
>>> unknown_token='[UNK]', lower_case=False, keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True,
>>> with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_unicode_script_tokenizer
def __init__(self, keep_whitespace=False, with_offsets=False):
keep_whitespace = replace_none(keep_whitespace, False)
with_offsets = replace_none(with_offsets, False)
@check_bert_tokenizer
def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, unknown_token='[UNK]',
lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
preserve_unused_token=True, with_offsets=False):
if not isinstance(normalization_form, NormalizeForm):
raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

self.vocab = vocab
self.suffix_indicator = suffix_indicator
self.max_bytes_per_token = max_bytes_per_token
self.unknown_token = unknown_token
self.lower_case = lower_case
self.keep_whitespace = keep_whitespace
self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form]
self.preserve_unused_token = preserve_unused_token
self.with_offsets = with_offsets

def parse(self):
return cde.UnicodeScriptTokenizerOperation(self.keep_whitespace, self.with_offsets)
return cde.BertTokenizerOperation(self.vocab, self.suffix_indicator, self.max_bytes_per_token,
self.unknown_token, self.lower_case, self.keep_whitespace,
self.normalization_form, self.preserve_unused_token, self.with_offsets)


class CaseFold(TextTensorOperation):
@@ -530,15 +714,6 @@ if platform.system().lower() != 'windows':
return cde.CaseFoldOperation()


DE_C_INTER_NORMALIZE_FORM = {
NormalizeForm.NONE: cde.NormalizeForm.DE_NORMALIZE_NONE,
NormalizeForm.NFC: cde.NormalizeForm.DE_NORMALIZE_NFC,
NormalizeForm.NFKC: cde.NormalizeForm.DE_NORMALIZE_NFKC,
NormalizeForm.NFD: cde.NormalizeForm.DE_NORMALIZE_NFD,
NormalizeForm.NFKD: cde.NormalizeForm.DE_NORMALIZE_NFKD
}


class NormalizeUTF8(TextTensorOperation):
"""
Apply normalize operation on UTF-8 string tensor.
@@ -651,218 +826,71 @@ if platform.system().lower() != 'windows':
return cde.RegexTokenizerOperation(self.delim_pattern, self.keep_delim_pattern, self.with_offsets)


class BasicTokenizer(TextTensorOperation):
class UnicodeScriptTokenizer(TextTensorOperation):
"""
Tokenize a scalar tensor of UTF-8 string by specific rules.
Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.

Note:
BasicTokenizer is not supported on Windows platform yet.
UnicodeScriptTokenizer is not supported on Windows platform yet.

Args:
lower_case (bool, optional): If True, apply CaseFold, NormalizeUTF8(NFD mode), RegexReplace operation
on input text to fold the text to lower case and strip accents characters. If False, only apply
NormalizeUTF8('normalization_form' mode) operation on input text (default=False).
keep_whitespace (bool, optional): If True, the whitespace will be kept in out tokens (default=False).
normalization_form (NormalizeForm, optional): Used to specify a specific normalize mode. This is
only effective when 'lower_case' is False. See NormalizeUTF8 for details (default=NormalizeForm.NONE).
preserve_unused_token (bool, optional): If True, do not split special tokens like
'[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]' (default=True).
keep_whitespace (bool, optional): If or not emit whitespace tokens (default=False).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.BasicTokenizer(lower_case=False,
>>> keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE,
>>> preserve_unused_token=True,
>>> with_offsets=False)
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=False)
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.BasicTokenizer(lower_case=False,
>>> keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE,
>>> preserve_unused_token=True,
>>> with_offsets=True)
>>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_basic_tokenizer
def __init__(self, lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
preserve_unused_token=True, with_offsets=False):
if not isinstance(normalization_form, NormalizeForm):
raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

self.lower_case = lower_case
@check_unicode_script_tokenizer
def __init__(self, keep_whitespace=False, with_offsets=False):
keep_whitespace = replace_none(keep_whitespace, False)
with_offsets = replace_none(with_offsets, False)
self.keep_whitespace = keep_whitespace
self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form]
self.preserve_unused_token = preserve_unused_token
self.with_offsets = with_offsets

def parse(self):
return cde.BasicTokenizerOperation(self.lower_case, self.keep_whitespace, self.normalization_form,
self.preserve_unused_token, self.with_offsets)
return cde.UnicodeScriptTokenizerOperation(self.keep_whitespace, self.with_offsets)


class BertTokenizer(TextTensorOperation):
class WhitespaceTokenizer(TextTensorOperation):
"""
Tokenizer used for Bert text process.
Tokenize a scalar tensor of UTF-8 string on ICU4C defined whitespaces, such as: ' ', '\\\\t', '\\\\r', '\\\\n'.

Note:
BertTokenizer is not supported on Windows platform yet.
WhitespaceTokenizer is not supported on Windows platform yet.

Args:
vocab (Vocab): A vocabulary object.
suffix_indicator (str, optional): Used to show that the subword is the last part of a word (default='##').
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split (default=100).
unknown_token (str, optional): When a token cannot be found: if 'unknown_token' is empty string,
return the token directly, else return 'unknown_token'(default='[UNK]').
lower_case (bool, optional): If True, apply CaseFold, NormalizeUTF8(NFD mode), RegexReplace operation
on input text to fold the text to lower case and strip accented characters. If False, only apply
NormalizeUTF8('normalization_form' mode) operation on input text (default=False).
keep_whitespace (bool, optional): If True, the whitespace will be kept in out tokens (default=False).
normalization_form (NormalizeForm, optional): Used to specify a specific normalize mode,
only effective when 'lower_case' is False. See NormalizeUTF8 for details (default='NONE').
preserve_unused_token (bool, optional): If True, do not split special tokens like
'[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]' (default=True).
with_offsets (bool, optional): If or not output offsets of tokens (default=False).

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # If with_offsets=False, default output one column {["text", dtype=str]}
>>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100,
>>> unknown_token='[UNK]', lower_case=False, keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True,
>>> with_offsets=False)
>>> tokenizer_op = text.WhitespaceTokenizer()
>>> data1 = data1.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100,
>>> unknown_token='[UNK]', lower_case=False, keep_whitespace=False,
>>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True,
>>> with_offsets=True)
>>> tokenizer_op = text.WhitespaceTokenizer(True)
>>> data2 = data2.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
"""

@check_bert_tokenizer
def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, unknown_token='[UNK]',
lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE,
preserve_unused_token=True, with_offsets=False):
if not isinstance(normalization_form, NormalizeForm):
raise TypeError("Wrong input type for normalization_form, should be enum of 'NormalizeForm'.")

self.vocab = vocab
self.suffix_indicator = suffix_indicator
self.max_bytes_per_token = max_bytes_per_token
self.unknown_token = unknown_token
self.lower_case = lower_case
self.keep_whitespace = keep_whitespace
self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form]
self.preserve_unused_token = preserve_unused_token
@check_with_offsets
def __init__(self, with_offsets=False):
self.with_offsets = with_offsets

def parse(self):
return cde.BertTokenizerOperation(self.vocab, self.suffix_indicator, self.max_bytes_per_token,
self.unknown_token, self.lower_case, self.keep_whitespace,
self.normalization_form, self.preserve_unused_token, self.with_offsets)


class TruncateSequencePair(TextTensorOperation):
"""
Truncate a pair of rank-1 tensors such that the total length is less than max_length.

This operation takes two input tensors and returns two output Tensors.

Args:
max_length (int): Maximum length required.

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> # Data before
>>> # | col1 | col2 |
>>> # +---------+---------|
>>> # | [1,2,3] | [4,5] |
>>> # +---------+---------+
>>> data1 = data1.map(operations=text.TruncateSequencePair(4))
>>> # Data after
>>> # | col1 | col2 |
>>> # +---------+---------+
>>> # | [1,2] | [4,5] |
>>> # +---------+---------+
"""

@check_pair_truncate
def __init__(self, max_length):
self.max_length = max_length

def parse(self):
return cde.TruncateSequencePairOperation(self.max_length)


class ToNumber(TextTensorOperation):
"""
Tensor operation to convert every element of a string tensor to a number.

Strings are casted according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an
unsigned integer type.

Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be
a numeric type.

Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being casted.

Examples:
>>> import mindspore.dataset.text as text
>>> import mindspore.common.dtype as mstype
>>>
>>> to_number_op = text.ToNumber(mstype.int8)
>>> data1 = data1.map(operations=to_number_op)
"""

@check_to_number
def __init__(self, data_type):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)

def parse(self):
return cde.ToNumberOperation(self.data_type)


class PythonTokenizer:
"""
Callable class to be used for user-defined string tokenizer.

Args:
tokenizer (Callable): Python function that takes a `str` and returns a list of `str` as tokens.

Examples:
>>> import mindspore.dataset.text as text
>>>
>>> def my_tokenizer(line):
>>> return line.split()
>>> data1 = data1.map(operations=text.PythonTokenizer(my_tokenizer))
"""

@check_python_tokenizer
def __init__(self, tokenizer):
self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')

def __call__(self, in_array):
in_array = to_str(in_array)
tokens = self.tokenizer(in_array)
return tokens
return cde.WhitespaceTokenizerOperation(self.with_offsets)
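
With TextTensorOperation now deriving from TensorOperation and gaining __call__, text ops can also be used eagerly. A minimal sketch following the test_lookup_callable test added to tests/ut/python/dataset/test_vocab.py below:

import mindspore.dataset.text as text

vocab = text.Vocab.from_list(['深', '圳', '欢', '迎', '您'])
lookup = text.Lookup(vocab)
print(lookup("迎"))   # 3, the id of "迎" in the vocabulary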

+10 -2  mindspore/dataset/transforms/c_transforms.py

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,6 +26,14 @@ from .validators import check_num_classes, check_de_type, check_fill_value, chec
from ..core.datatypes import mstype_to_detype


class TensorOperation:
def __call__(self):
raise NotImplementedError("TensorOperation has to implement __call__() method.")

def parse(self):
raise NotImplementedError("TensorOperation has to implement parse() method.")


class OneHot(cde.OneHotOp):
"""
Tensor operation to apply one hot encoding.
@@ -304,7 +312,7 @@ class Unique(cde.UniqueOp):
Also return an index tensor that contains the index of each element of the
input tensor in the Unique output tensor.

Finally, return a count tensor that constains the count of each element of
Finally, return a count tensor that contains the count of each element of
the output tensor in the input tensor.

Note:


+11 -9  mindspore/dataset/transforms/validators.py

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -237,8 +237,8 @@ def check_compose_list(method):
type_check(transforms, (list,), transforms)
if not transforms:
raise ValueError("transforms list is empty.")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
for i, transform in enumerate(transforms):
if not callable(transform):
raise ValueError("transforms[{}] is not callable.".format(i))
return method(self, *args, **kwargs)

@@ -269,9 +269,10 @@ def check_random_apply(method):
[transforms, prob], _ = parse_user_args(method, *args, **kwargs)
type_check(transforms, (list,), "transforms")

for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
for i, transform in enumerate(transforms):
if str(transform).find("c_transform") >= 0:
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
.format(i))

if prob is not None:
type_check(prob, (float, int,), "prob")
@@ -290,9 +291,10 @@ def check_transforms_list(method):
[transforms], _ = parse_user_args(method, *args, **kwargs)

type_check(transforms, (list,), "transforms")
for i, transfrom in enumerate(transforms):
if not callable(transfrom):
raise ValueError("transforms[{}] is not callable.".format(i))
for i, transform in enumerate(transforms):
if str(transform).find("c_transform") >= 0:
raise ValueError("transforms[{}] is not a py transforms. Should not use a c transform in py transform" \
.format(i))
return method(self, *args, **kwargs)

return new_method
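
The validator change above now rejects c_transforms ops passed to py_transforms.Random(Apply/Choice/Order) with a dedicated message instead of the old "not callable" error. A small sketch of the new behaviour, assuming the check fires when the py_transforms wrapper is constructed (mirroring the updated assertions in test_compose.py below):

import pytest
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.vision.c_transforms as c_vision

with pytest.raises(ValueError) as error_info:
    py_transforms.RandomApply([c_vision.RandomResizedCrop(200)])
assert "transforms[0] is not a py transforms." in str(error_info.value)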

+631 -696  mindspore/dataset/vision/c_transforms.py (file diff suppressed because it is too large)


+16 -1  tests/ut/python/dataset/test_HWC2CHW.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +29,20 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_HWC2CHW_callable():
"""
Test HWC2CHW is callable
"""
logger.info("Test HWC2CHW callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

img = c_vision.Decode()(img)
img = c_vision.HWC2CHW()(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
assert img.shape == (3, 2268, 4032)


def test_HWC2CHW(plot=False):
"""
Test HWC2CHW
@@ -122,6 +136,7 @@ def test_HWC2CHW_comp(plot=False):


if __name__ == '__main__':
test_HWC2CHW_callable()
test_HWC2CHW(True)
test_HWC2CHW_md5()
test_HWC2CHW_comp(True)

+5 -5  tests/ut/python/dataset/test_compose.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -219,7 +219,7 @@ def test_c_py_compose_vision_module(plot=False, run_golden=True):

def test_py_transforms_with_c_vision():
"""
These examples will fail, as py_transforms.Random(Apply/Choice/Order) expect callable functions
These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)
"""

ds.config.set_seed(0)
@@ -236,15 +236,15 @@ def test_py_transforms_with_c_vision():

with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
assert "transforms[0] is not callable." in str(error_info.value)
assert "transforms[0] is not a py transforms." in str(error_info.value)

with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
assert "transforms[0] is not callable." in str(error_info.value)
assert "transforms[0] is not a py transforms." in str(error_info.value)

with pytest.raises(ValueError) as error_info:
test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
assert "transforms[1] is not callable." in str(error_info.value)
assert "transforms[1] is not a py transforms." in str(error_info.value)

with pytest.raises(RuntimeError) as error_info:
test_config([py_transforms.OneHotOp(20, 0.1)])


+17 -1  tests/ut/python/dataset/test_invert.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +29,21 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False


def test_invert_callable():
"""
Test Invert is callable
"""
logger.info("Test Invert callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

img = C.Decode()(img)
img = C.Invert()(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

assert img.shape == (2268, 4032, 3)


def test_invert_py(plot=False):
"""
Test Invert python op
@@ -247,6 +262,7 @@ def test_invert_md5_c():


if __name__ == "__main__":
test_invert_callable()
test_invert_py(plot=False)
test_invert_c(plot=False)
test_invert_py_c(plot=False)


+18 -1  tests/ut/python/dataset/test_random_crop_and_resize.py

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,6 +34,22 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
GENERATE_GOLDEN = False


def test_random_crop_and_resize_callable():
"""
Test RandomCropAndResize op is callable
"""
logger.info("test_random_crop_and_resize_callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

decode_op = c_vision.Decode()
img = decode_op(img)

random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
img = random_crop_and_resize_op(img)
assert np.shape(img) == (256, 512, 3)


def test_random_crop_and_resize_op_c(plot=False):
"""
Test RandomCropAndResize op in c transforms
@@ -389,6 +405,7 @@ def test_random_crop_and_resize_06():


if __name__ == "__main__":
test_random_crop_and_resize_callable()
test_random_crop_and_resize_op_c(True)
test_random_crop_and_resize_op_py(True)
test_random_crop_and_resize_op_py_ANTIALIAS()


+20 -1  tests/ut/python/dataset/test_text_jieba_tokenizer.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.text import JiebaTokenizer
from mindspore.dataset.text import JiebaMode, to_str
from mindspore import log as logger

DATA_FILE = "../data/dataset/testJiebaDataset/3.txt"
DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*"
@@ -24,6 +25,23 @@ HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8"
MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8"


def test_jieba_callable():
"""
Test jieba tokenizer op is callable
"""
logger.info("test_jieba_callable")
jieba_op1 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op2 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM)

text1 = "今天天气太好了我们一起去外面玩吧"
text2 = "男默女泪市长江大桥"
assert np.array_equal(jieba_op1(text1), ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'])
assert np.array_equal(jieba_op2(text1), ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧'])

jieba_op1.add_word("男默女泪")
assert np.array_equal(jieba_op1(text2), ['男默女泪', '市', '长江大桥'])


def test_jieba_1():
"""Test jieba tokenizer with MP mode"""
data = ds.TextFileDataset(DATA_FILE)
@@ -457,6 +475,7 @@ def test_jieba_6():


if __name__ == "__main__":
test_jieba_callable()
test_jieba_1()
test_jieba_1_1()
test_jieba_1_2()


+20 -1  tests/ut/python/dataset/test_uniform_augment.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,6 +28,24 @@ from util import visualize_list, diff_mse
DATA_DIR = "../data/dataset/testImageNetData/train/"


def test_uniform_augment_callable(num_ops=2):
"""
Test UniformAugment is callable
"""
logger.info("test_uniform_augment_callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

decode_op = C.Decode()
img = decode_op(img)

transforms_ua = [C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
img = uni_aug([img, img])
assert ((np.shape(img) == (2, 2268, 4032, 3)) or (np.shape(img) == (1, 400, 400, 3)))


def test_uniform_augment(plot=False, num_ops=2):
"""
Test UniformAugment
@@ -262,6 +280,7 @@ def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):


if __name__ == "__main__":
test_uniform_augment_callable(num_ops=2)
test_uniform_augment(num_ops=1, plot=True)
test_cpp_uniform_augment(num_ops=1, plot=True)
test_cpp_uniform_augment_exception_pyops(num_ops=1)


+13 -1  tests/ut/python/dataset/test_vocab.py

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.common.dtype as mstype
from mindspore import log as logger

# this file contains "home is behind the world head" each word is 1 line
DATA_FILE = "../data/dataset/testVocab/words.txt"
@@ -25,6 +26,16 @@ VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt"


def test_lookup_callable():
"""
Test lookup is callable
"""
logger.info("test_lookup_callable")
vocab = text.Vocab.from_list(['深', '圳', '欢', '迎', '您'])
lookup = text.Lookup(vocab)
word = "迎"
assert lookup(word) == 3

def test_from_list_tutorial():
vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["<pad>", "<unk>"], True)
lookup = text.Lookup(vocab, "<unk>")
@@ -171,6 +182,7 @@ def test_lookup_cast_type():


if __name__ == '__main__':
test_lookup_callable()
test_from_dict_exception()
test_from_list_tutorial()
test_from_file_tutorial()

