| @@ -323,6 +323,7 @@ Status DatasetOp::GetNumClasses(int64_t *num_classes) { | |||||
| return child_[child_.size() - 1]->GetNumClasses(num_classes); | return child_[child_.size() - 1]->GetNumClasses(num_classes); | ||||
| } else { | } else { | ||||
| // when num classes isn't found, the default behavior is to return -1 | // when num classes isn't found, the default behavior is to return -1 | ||||
| MS_LOG(WARNING) << "Num classes not defined for : " << Name(); | |||||
| *num_classes = -1; | *num_classes = -1; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -54,15 +54,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<FilterOp> node, bool * | |||||
| Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) { | Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) { | ||||
| RETURN_IF_NOT_OK(pass_.Run(tree, modified)); | RETURN_IF_NOT_OK(pass_.Run(tree, modified)); | ||||
| // nested private class variables can be directly accessed by its outer class | |||||
| for (auto node : pass_.nodes_to_remove_) { | |||||
| DatasetOp *parent; | |||||
| node->Parent(&parent, 0); | |||||
| // only remove node whose is a single child of its parent | |||||
| if (parent != nullptr && parent->Children().size() == 1) { | |||||
| RETURN_IF_NOT_OK(node->Remove()); | |||||
| } | |||||
| } | |||||
| // currently the getter pass only disables call_back from the execution tree | |||||
| // clear the callback for selected ops (map when its GetOutputType/Shape) | // clear the callback for selected ops (map when its GetOutputType/Shape) | ||||
| for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks(); | for (auto node : pass_.nodes_to_clear_callback_) node->ClearCallbacks(); | ||||
| @@ -131,7 +131,7 @@ Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text | |||||
| Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, | Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, | ||||
| std::shared_ptr<Tensor> *output) { | std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string."); | |||||
| std::vector<std::string> strs(input->Size()); | std::vector<std::string> strs(input->Size()); | ||||
| int i = 0; | int i = 0; | ||||
| for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | ||||
| @@ -29,7 +29,7 @@ namespace dataset { | |||||
| Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string."); | |||||
| icu::ErrorCode error; | icu::ErrorCode error; | ||||
| const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); | const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); | CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); | ||||
| @@ -33,11 +33,11 @@ JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::strin | |||||
| Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | ||||
| IO_CHECK_VECTOR(input, output); | IO_CHECK_VECTOR(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor."); | |||||
| RETURN_UNEXPECTED_IF_NULL(jieba_parser_); | RETURN_UNEXPECTED_IF_NULL(jieba_parser_); | ||||
| if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); | |||||
| RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor."); | |||||
| } | } | ||||
| std::string_view sentence_v; | std::string_view sentence_v; | ||||
| @@ -35,7 +35,7 @@ NgramOp::NgramOp(const std::vector<int32_t> &ngrams, int32_t l_len, int32_t r_le | |||||
| Status NgramOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status NgramOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor."); | |||||
| std::vector<int32_t> offsets; // offsets for each str | std::vector<int32_t> offsets; // offsets for each str | ||||
| std::vector<std::string> res; // holds the result of ngrams | std::vector<std::string> res; // holds the result of ngrams | ||||
| std::string str_buffer; // concat all pad tokens with string interleaved with separators | std::string str_buffer; // concat all pad tokens with string interleaved with separators | ||||
| @@ -60,7 +60,7 @@ Status NgramOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te | |||||
| if (end_ind - start_ind <= n) { | if (end_ind - start_ind <= n) { | ||||
| res.emplace_back(std::string()); // push back empty string | res.emplace_back(std::string()); // push back empty string | ||||
| } else { | } else { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition."); | |||||
| for (int i = start_ind; i < end_ind - n; i++) { | for (int i = start_ind; i < end_ind - n; i++) { | ||||
| res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); | res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); | ||||
| @@ -29,7 +29,7 @@ namespace dataset { | |||||
| const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; | const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; | ||||
| Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string."); | |||||
| icu::ErrorCode error; | icu::ErrorCode error; | ||||
| const icu::Normalizer2 *normalize = nullptr; | const icu::Normalizer2 *normalize = nullptr; | ||||
| @@ -40,26 +40,26 @@ Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::share | |||||
| } | } | ||||
| case NormalizeForm::kNfc: { | case NormalizeForm::kNfc: { | ||||
| normalize = icu::Normalizer2::getNFCInstance(error); | normalize = icu::Normalizer2::getNFCInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed."); | |||||
| break; | break; | ||||
| } | } | ||||
| case NormalizeForm::kNfkc: { | case NormalizeForm::kNfkc: { | ||||
| normalize = icu::Normalizer2::getNFKCInstance(error); | normalize = icu::Normalizer2::getNFKCInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed."); | |||||
| break; | break; | ||||
| } | } | ||||
| case NormalizeForm::kNfd: { | case NormalizeForm::kNfd: { | ||||
| normalize = icu::Normalizer2::getNFDInstance(error); | normalize = icu::Normalizer2::getNFDInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed."); | |||||
| break; | break; | ||||
| } | } | ||||
| case NormalizeForm::kNfkd: { | case NormalizeForm::kNfkd: { | ||||
| normalize = icu::Normalizer2::getNFKDInstance(error); | normalize = icu::Normalizer2::getNFKDInstance(error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed."); | |||||
| break; | break; | ||||
| } | } | ||||
| default: { | default: { | ||||
| RETURN_STATUS_UNEXPECTED("unexpected normalize form"); | |||||
| RETURN_STATUS_UNEXPECTED("Unexpected normalize form."); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::share | |||||
| for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | ||||
| icu::StringByteSink<std::string> sink(&strs[i++]); | icu::StringByteSink<std::string> sink(&strs[i++]); | ||||
| normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); | normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "NormalizeUTF8 failed."); | |||||
| } | } | ||||
| return Tensor::CreateFromVector(strs, input->shape(), output); | return Tensor::CreateFromVector(strs, input->shape(), output); | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ namespace dataset { | |||||
| Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, | Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, | ||||
| std::string *out) const { | std::string *out) const { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null."); | |||||
| UErrorCode icu_error = U_ZERO_ERROR; | UErrorCode icu_error = U_ZERO_ERROR; | ||||
| icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); | icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); | ||||
| matcher->reset(unicode_text); | matcher->reset(unicode_text); | ||||
| @@ -35,17 +35,18 @@ Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std: | |||||
| } else { | } else { | ||||
| unicode_out = matcher->replaceFirst(replace_, icu_error); | unicode_out = matcher->replaceFirst(replace_, icu_error); | ||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed."); | |||||
| unicode_out.toUTF8String(*out); | unicode_out.toUTF8String(*out); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | ||||
| IO_CHECK(input, output); | IO_CHECK(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensor not of type string."); | |||||
| UErrorCode icu_error = U_ZERO_ERROR; | UErrorCode icu_error = U_ZERO_ERROR; | ||||
| icu::RegexMatcher matcher(pattern_, 0, icu_error); | icu::RegexMatcher matcher(pattern_, 0, icu_error); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), | |||||
| "Create icu RegexMatcher failed, you may input one error pattern."); | |||||
| std::vector<std::string> strs(input->Size()); | std::vector<std::string> strs(input->Size()); | ||||
| int i = 0; | int i = 0; | ||||
| for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { | ||||
| @@ -56,7 +56,7 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s | |||||
| } | } | ||||
| if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { | if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); | |||||
| RETURN_STATUS_UNEXPECTED("Input tensor should be scalar string tensor."); | |||||
| } | } | ||||
| std::string_view sentence_v; | std::string_view sentence_v; | ||||
| @@ -67,14 +67,14 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s | |||||
| std::vector<std::string> pieces; | std::vector<std::string> pieces; | ||||
| auto status = processor_.Encode(sentence, &pieces); | auto status = processor_.Encode(sentence, &pieces); | ||||
| if (!status.ok()) { | if (!status.ok()) { | ||||
| RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); | |||||
| RETURN_STATUS_UNEXPECTED("Sentence piece tokenizer error."); | |||||
| } | } | ||||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(pieces, output)); | RETURN_IF_NOT_OK(Tensor::CreateFromVector(pieces, output)); | ||||
| } else { | } else { | ||||
| std::vector<int> ids; | std::vector<int> ids; | ||||
| auto status = processor_.Encode(sentence, &ids); | auto status = processor_.Encode(sentence, &ids); | ||||
| if (!status.ok()) { | if (!status.ok()) { | ||||
| RETURN_STATUS_UNEXPECTED("sentence piece tokenizer error"); | |||||
| RETURN_STATUS_UNEXPECTED("Sentence piece tokenizer error."); | |||||
| } | } | ||||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(ids, output)); | RETURN_IF_NOT_OK(Tensor::CreateFromVector(ids, output)); | ||||
| } | } | ||||
| @@ -84,15 +84,15 @@ Status SentencePieceTokenizerOp::Compute(const std::shared_ptr<Tensor> &input, s | |||||
| Status SentencePieceTokenizerOp::GetModelRealPath(const std::string &model_path, const std::string &filename) { | Status SentencePieceTokenizerOp::GetModelRealPath(const std::string &model_path, const std::string &filename) { | ||||
| char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
| if (file_path_.size() >= PATH_MAX) { | if (file_path_.size() >= PATH_MAX) { | ||||
| RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); | |||||
| RETURN_STATUS_UNEXPECTED("Sentence piece model path is invalid."); | |||||
| } | } | ||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| if (_fullpath(real_path, common::SafeCStr(model_path), PATH_MAX) == nullptr) { | if (_fullpath(real_path, common::SafeCStr(model_path), PATH_MAX) == nullptr) { | ||||
| RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); | |||||
| RETURN_STATUS_UNEXPECTED("Sentence piece model path is invalid."); | |||||
| } | } | ||||
| #else | #else | ||||
| if (realpath(common::SafeCStr(model_path), real_path) == nullptr) { | if (realpath(common::SafeCStr(model_path), real_path) == nullptr) { | ||||
| RETURN_STATUS_UNEXPECTED("sentence piece model path is invalid."); | |||||
| RETURN_STATUS_UNEXPECTED("Sentence piece model path is invalid."); | |||||
| } | } | ||||
| #endif | #endif | ||||
| std::string abs_path = real_path; | std::string abs_path = real_path; | ||||
| @@ -29,7 +29,7 @@ Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output | |||||
| std::shared_ptr<Tensor> seq1 = input[0]; | std::shared_ptr<Tensor> seq1 = input[0]; | ||||
| std::shared_ptr<Tensor> seq2 = input[1]; | std::shared_ptr<Tensor> seq2 = input[1]; | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, | CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, | ||||
| "Both sequences should be of rank 1"); | |||||
| "Both sequences should be of rank 1."); | |||||
| dsize_t length1 = seq1->shape()[0]; | dsize_t length1 = seq1->shape()[0]; | ||||
| dsize_t length2 = seq2->shape()[0]; | dsize_t length2 = seq2->shape()[0]; | ||||
| dsize_t outLength1 = length1; | dsize_t outLength1 = length1; | ||||
| @@ -31,9 +31,9 @@ const bool UnicodeCharTokenizerOp::kDefWithOffsets = false; | |||||
| Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | ||||
| IO_CHECK_VECTOR(input, output); | IO_CHECK_VECTOR(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor."); | |||||
| if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); | |||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor."); | |||||
| } | } | ||||
| std::string_view str; | std::string_view str; | ||||
| RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); | RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); | ||||
| @@ -35,9 +35,9 @@ const bool WhitespaceTokenizerOp::kDefWithOffsets = false; | |||||
| Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | ||||
| IO_CHECK_VECTOR(input, output); | IO_CHECK_VECTOR(input, output); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor."); | |||||
| if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); | |||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor."); | |||||
| } | } | ||||
| std::string_view str; | std::string_view str; | ||||
| RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); | RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); | ||||
| @@ -117,7 +117,7 @@ Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uin | |||||
| Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { | ||||
| IO_CHECK_VECTOR(input, output); | IO_CHECK_VECTOR(input, output); | ||||
| if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { | if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { | ||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); | |||||
| RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor."); | |||||
| } | } | ||||
| dsize_t count = 0; | dsize_t count = 0; | ||||
| std::vector<std::string> out_tokens; | std::vector<std::string> out_tokens; | ||||
| @@ -95,9 +95,9 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestOutputShapeAndTypePass) { | |||||
| // +- ( 4) <RandomDataOp>: [workers: 4] [total rows: 44] | // +- ( 4) <RandomDataOp>: [workers: 4] [total rows: 44] | ||||
| // | // | ||||
| // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not | |||||
| EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); | |||||
| EXPECT_EQ(ss_str.find("RepeatOp"), ss_str.npos); | |||||
| // verify that no ops are removed, but Batch and ProjectOp are not | |||||
| EXPECT_NE(ss_str.find("ShuffleOp"), ss_str.npos); | |||||
| EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | |||||
| EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | ||||
| EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | ||||
| } | } | ||||
| @@ -129,8 +129,8 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { | |||||
| exe_tree->Print(ss); | exe_tree->Print(ss); | ||||
| std::string ss_str = ss.str(); | std::string ss_str = ss.str(); | ||||
| // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not | |||||
| EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); | |||||
| // verify that no ops are removed, but Batch and ProjectOp are not | |||||
| EXPECT_NE(ss_str.find("ShuffleOp"), ss_str.npos); | |||||
| EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | ||||
| EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | ||||
| EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | ||||