| @@ -131,6 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text | |||||
| Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, | Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, | ||||
| std::shared_ptr<Tensor> *output) { | std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| std::vector<std::string> strs(input->Size()); | std::vector<std::string> strs(input->Size()); | ||||
| int i = 0; | int i = 0; | ||||
| for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | ||||
| @@ -29,6 +29,7 @@ namespace dataset { | |||||
| Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| icu::ErrorCode error; | icu::ErrorCode error; | ||||
| const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); | const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); | CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); | ||||
| @@ -29,6 +29,8 @@ namespace dataset { | |||||
| const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; | const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; | ||||
| Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| icu::ErrorCode error; | icu::ErrorCode error; | ||||
| const icu::Normalizer2 *normalize = nullptr; | const icu::Normalizer2 *normalize = nullptr; | ||||
| switch (normalize_form_) { | switch (normalize_form_) { | ||||
| @@ -42,6 +42,7 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std: | |||||
| Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| UErrorCode icu_error = U_ZERO_ERROR; | UErrorCode icu_error = U_ZERO_ERROR; | ||||
| icu::RegexMatcher matcher(pattern_, 0, icu_error); | icu::RegexMatcher matcher(pattern_, 0, icu_error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); | CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); | ||||
| @@ -1940,7 +1940,7 @@ class BlockReleasePair: | |||||
| Args: | Args: | ||||
| init_release_rows (int): Number of lines to allow through the pipeline. | init_release_rows (int): Number of lines to allow through the pipeline. | ||||
| callback (function): The callback function that will be called when release is called. | |||||
| callback (function): The callback function that will be called when release is called (default=None). | |||||
| """ | """ | ||||
| def __init__(self, init_release_rows, callback=None): | def __init__(self, init_release_rows, callback=None): | ||||
| @@ -2015,7 +2015,7 @@ class SyncWaitDataset(Dataset): | |||||
| input_dataset (Dataset): Input dataset to apply flow control. | input_dataset (Dataset): Input dataset to apply flow control. | ||||
| num_batch (int): Number of batches without blocking at the start of each epoch. | num_batch (int): Number of batches without blocking at the start of each epoch. | ||||
| condition_name (str): Condition name that is used to toggle sending next row. | condition_name (str): Condition name that is used to toggle sending next row. | ||||
| callback (function): Callback function that will be invoked when sync_update is called. | |||||
| callback (function): Callback function that will be invoked when sync_update is called (default=None). | |||||
| Raises: | Raises: | ||||
| RuntimeError: If condition name already exists. | RuntimeError: If condition name already exists. | ||||
| @@ -270,6 +270,21 @@ def test_simple_sync_wait_empty_condition_name(): | |||||
| dataset.sync_update(condition_name="", data=data) | dataset.sync_update(condition_name="", data=data) | ||||
| def test_sync_exception_06(): | |||||
| """ | |||||
| Test sync: with string batch size | |||||
| """ | |||||
| logger.info("test_sync_exception_03") | |||||
| dataset = ds.GeneratorDataset(gen, column_names=["input"]) | |||||
| aug = Augment(0) | |||||
| # try to create dataset with batch_size < 0 | |||||
| with pytest.raises(TypeError) as e: | |||||
| dataset.sync_wait(condition_name="every batch", num_batch="123", callback=aug.update) | |||||
| assert "is not of type (<class 'int'>" in str(e.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_simple_sync_wait() | test_simple_sync_wait() | ||||
| test_simple_shuffle_sync() | test_simple_shuffle_sync() | ||||
| @@ -279,6 +294,7 @@ if __name__ == "__main__": | |||||
| test_sync_exception_03() | test_sync_exception_03() | ||||
| test_sync_exception_04() | test_sync_exception_04() | ||||
| test_sync_exception_05() | test_sync_exception_05() | ||||
| test_sync_exception_06() | |||||
| test_sync_epoch() | test_sync_epoch() | ||||
| test_multiple_iterators() | test_multiple_iterators() | ||||
| test_simple_sync_wait_empty_condition_name() | test_simple_sync_wait_empty_condition_name() | ||||