You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

tokenizer_op_test.cc 13 kB


  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <memory>
  #include <string>
  #include <string_view>
  #include <vector>
  #include "common/common.h"
  #include "dataset/text/kernels/basic_tokenizer_op.h"
  #include "dataset/text/kernels/case_fold_op.h"
  #include "dataset/text/kernels/normalize_utf8_op.h"
  #include "dataset/text/kernels/regex_replace_op.h"
  #include "dataset/text/kernels/regex_tokenizer_op.h"
  #include "dataset/text/kernels/unicode_char_tokenizer_op.h"
  #include "dataset/text/kernels/unicode_script_tokenizer_op.h"
  #include "dataset/text/kernels/whitespace_tokenizer_op.h"
  #include "gtest/gtest.h"
  #include "utils/log_adapter.h"
  using namespace mindspore::dataset;
  31. class MindDataTestTokenizerOp : public UT::Common {
  32. public:
  33. void CheckEqual(const std::shared_ptr<Tensor> &o,
  34. const std::vector<dsize_t> &index,
  35. const std::string &expect) {
  36. std::string_view str;
  37. Status s = o->GetItemAt(&str, index);
  38. EXPECT_TRUE(s.IsOk());
  39. EXPECT_EQ(str, expect);
  40. }
  41. };
  42. TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) {
  43. MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp.";
  44. std::unique_ptr<UnicodeCharTokenizerOp> op(new UnicodeCharTokenizerOp());
  45. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Hello World!");
  46. std::shared_ptr<Tensor> output;
  47. Status s = op->Compute(input, &output);
  48. EXPECT_TRUE(s.IsOk());
  49. EXPECT_EQ(output->Size(), 12);
  50. EXPECT_EQ(output->Rank(), 1);
  51. MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  52. CheckEqual(output, {0}, "H");
  53. CheckEqual(output, {1}, "e");
  54. CheckEqual(output, {2}, "l");
  55. CheckEqual(output, {3}, "l");
  56. CheckEqual(output, {4}, "o");
  57. CheckEqual(output, {5}, " ");
  58. CheckEqual(output, {6}, "W");
  59. CheckEqual(output, {7}, "o");
  60. CheckEqual(output, {8}, "r");
  61. CheckEqual(output, {9}, "l");
  62. CheckEqual(output, {10}, "d");
  63. CheckEqual(output, {11}, "!");
  64. input = std::make_shared<Tensor>("中国 你好!");
  65. s = op->Compute(input, &output);
  66. EXPECT_TRUE(s.IsOk());
  67. EXPECT_EQ(output->Size(), 6);
  68. EXPECT_EQ(output->Rank(), 1);
  69. MS_LOG(INFO) << "Out tensor2: " << output->ToString();
  70. CheckEqual(output, {0}, "中");
  71. CheckEqual(output, {1}, "国");
  72. CheckEqual(output, {2}, " ");
  73. CheckEqual(output, {3}, "你");
  74. CheckEqual(output, {4}, "好");
  75. CheckEqual(output, {5}, "!");
  76. input = std::make_shared<Tensor>("中");
  77. s = op->Compute(input, &output);
  78. EXPECT_TRUE(s.IsOk());
  79. EXPECT_EQ(output->Size(), 1);
  80. EXPECT_EQ(output->Rank(), 1);
  81. MS_LOG(INFO) << "Out tensor3: " << output->ToString();
  82. CheckEqual(output, {0}, "中");
  83. input = std::make_shared<Tensor>("H");
  84. s = op->Compute(input, &output);
  85. EXPECT_TRUE(s.IsOk());
  86. EXPECT_EQ(output->Size(), 1);
  87. EXPECT_EQ(output->Rank(), 1);
  88. MS_LOG(INFO) << "Out tensor4: " << output->ToString();
  89. CheckEqual(output, {0}, "H");
  90. input = std::make_shared<Tensor>(" ");
  91. s = op->Compute(input, &output);
  92. EXPECT_TRUE(s.IsOk());
  93. EXPECT_EQ(output->Size(), 2);
  94. EXPECT_EQ(output->Rank(), 1);
  95. MS_LOG(INFO) << "Out tensor5: " << output->ToString();
  96. CheckEqual(output, {0}, " ");
  97. CheckEqual(output, {1}, " ");
  98. input = std::make_shared<Tensor>("");
  99. s = op->Compute(input, &output);
  100. EXPECT_TRUE(s.IsOk());
  101. EXPECT_EQ(output->Size(), 1);
  102. EXPECT_EQ(output->Rank(), 1);
  103. MS_LOG(INFO) << "Out tensor6: " << output->ToString();
  104. CheckEqual(output, {0}, "");
  105. }
  106. TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) {
  107. MS_LOG(INFO) << "Doing TestWhitespaceTokenizerOp.";
  108. std::unique_ptr<WhitespaceTokenizerOp> op(new WhitespaceTokenizerOp());
  109. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China.");
  110. std::shared_ptr<Tensor> output;
  111. Status s = op->Compute(input, &output);
  112. EXPECT_TRUE(s.IsOk());
  113. EXPECT_EQ(output->Size(), 3);
  114. EXPECT_EQ(output->Rank(), 1);
  115. MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  116. CheckEqual(output, {0}, "Welcome");
  117. CheckEqual(output, {1}, "to");
  118. CheckEqual(output, {2}, "China.");
  119. input = std::make_shared<Tensor>(" hello");
  120. s = op->Compute(input, &output);
  121. EXPECT_TRUE(s.IsOk());
  122. EXPECT_EQ(output->Size(), 1);
  123. EXPECT_EQ(output->Rank(), 1);
  124. MS_LOG(INFO) << "Out tensor2: " << output->ToString();
  125. CheckEqual(output, {0}, "hello");
  126. input = std::make_shared<Tensor>("hello");
  127. s = op->Compute(input, &output);
  128. EXPECT_TRUE(s.IsOk());
  129. EXPECT_EQ(output->Size(), 1);
  130. EXPECT_EQ(output->Rank(), 1);
  131. MS_LOG(INFO) << "Out tensor3: " << output->ToString();
  132. CheckEqual(output, {0}, "hello");
  133. input = std::make_shared<Tensor>("hello ");
  134. s = op->Compute(input, &output);
  135. EXPECT_TRUE(s.IsOk());
  136. EXPECT_EQ(output->Size(), 1);
  137. EXPECT_EQ(output->Rank(), 1);
  138. MS_LOG(INFO) << "Out tensor4: " << output->ToString();
  139. CheckEqual(output, {0}, "hello");
  140. input = std::make_shared<Tensor>(" ");
  141. s = op->Compute(input, &output);
  142. EXPECT_TRUE(s.IsOk());
  143. EXPECT_EQ(output->Size(), 1);
  144. EXPECT_EQ(output->Rank(), 1);
  145. MS_LOG(INFO) << "Out tensor5: " << output->ToString();
  146. CheckEqual(output, {0}, "");
  147. }
  148. TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) {
  149. MS_LOG(INFO) << "Doing TestUnicodeScriptTokenizer.";
  150. std::unique_ptr<UnicodeScriptTokenizerOp> keep_whitespace_op(new UnicodeScriptTokenizerOp(true));
  151. std::unique_ptr<UnicodeScriptTokenizerOp> skip_whitespace_op(new UnicodeScriptTokenizerOp(false));
  152. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China. \n 中国\t北京");
  153. std::shared_ptr<Tensor> output;
  154. Status s = keep_whitespace_op->Compute(input, &output);
  155. EXPECT_TRUE(s.IsOk());
  156. EXPECT_EQ(output->Size(), 10);
  157. EXPECT_EQ(output->Rank(), 1);
  158. MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  159. CheckEqual(output, {0}, "Welcome");
  160. CheckEqual(output, {1}, " ");
  161. CheckEqual(output, {2}, "to");
  162. CheckEqual(output, {3}, " ");
  163. CheckEqual(output, {4}, "China");
  164. CheckEqual(output, {5}, ".");
  165. CheckEqual(output, {6}, " \n ");
  166. CheckEqual(output, {7}, "中国");
  167. CheckEqual(output, {8}, "\t");
  168. CheckEqual(output, {9}, "北京");
  169. s = skip_whitespace_op->Compute(input, &output);
  170. EXPECT_TRUE(s.IsOk());
  171. EXPECT_EQ(output->Size(), 6);
  172. EXPECT_EQ(output->Rank(), 1);
  173. MS_LOG(INFO) << "Out tensor2: " << output->ToString();
  174. CheckEqual(output, {0}, "Welcome");
  175. CheckEqual(output, {1}, "to");
  176. CheckEqual(output, {2}, "China");
  177. CheckEqual(output, {3}, ".");
  178. CheckEqual(output, {4}, "中国");
  179. CheckEqual(output, {5}, "北京");
  180. input = std::make_shared<Tensor>(" Welcome to 中国. ");
  181. s = skip_whitespace_op->Compute(input, &output);
  182. EXPECT_TRUE(s.IsOk());
  183. EXPECT_EQ(output->Size(), 4);
  184. EXPECT_EQ(output->Rank(), 1);
  185. MS_LOG(INFO) << "Out tensor3: " << output->ToString();
  186. CheckEqual(output, {0}, "Welcome");
  187. CheckEqual(output, {1}, "to");
  188. CheckEqual(output, {2}, "中国");
  189. CheckEqual(output, {3}, ".");
  190. s = keep_whitespace_op->Compute(input, &output);
  191. EXPECT_TRUE(s.IsOk());
  192. EXPECT_EQ(output->Size(), 8);
  193. EXPECT_EQ(output->Rank(), 1);
  194. MS_LOG(INFO) << "Out tensor4: " << output->ToString();
  195. CheckEqual(output, {0}, " ");
  196. CheckEqual(output, {1}, "Welcome");
  197. CheckEqual(output, {2}, " ");
  198. CheckEqual(output, {3}, "to");
  199. CheckEqual(output, {4}, " ");
  200. CheckEqual(output, {5}, "中国");
  201. CheckEqual(output, {6}, ".");
  202. CheckEqual(output, {7}, " ");
  203. input = std::make_shared<Tensor>("Hello");
  204. s = keep_whitespace_op->Compute(input, &output);
  205. EXPECT_TRUE(s.IsOk());
  206. EXPECT_EQ(output->Size(), 1);
  207. EXPECT_EQ(output->Rank(), 1);
  208. MS_LOG(INFO) << "Out tensor5: " << output->ToString();
  209. CheckEqual(output, {0}, "Hello");
  210. input = std::make_shared<Tensor>("H");
  211. s = keep_whitespace_op->Compute(input, &output);
  212. EXPECT_TRUE(s.IsOk());
  213. EXPECT_EQ(output->Size(), 1);
  214. EXPECT_EQ(output->Rank(), 1);
  215. MS_LOG(INFO) << "Out tensor6: " << output->ToString();
  216. CheckEqual(output, {0}, "H");
  217. input = std::make_shared<Tensor>("");
  218. s = keep_whitespace_op->Compute(input, &output);
  219. EXPECT_TRUE(s.IsOk());
  220. EXPECT_EQ(output->Size(), 1);
  221. EXPECT_EQ(output->Rank(), 1);
  222. MS_LOG(INFO) << "Out tensor7: " << output->ToString();
  223. CheckEqual(output, {0}, "");
  224. input = std::make_shared<Tensor>("Hello中国Hello世界");
  225. s = keep_whitespace_op->Compute(input, &output); EXPECT_TRUE(s.IsOk());
  226. EXPECT_EQ(output->Size(), 4);
  227. EXPECT_EQ(output->Rank(), 1);
  228. MS_LOG(INFO) << "Out tensor8: " << output->ToString();
  229. CheckEqual(output, {0}, "Hello");
  230. CheckEqual(output, {1}, "中国");
  231. CheckEqual(output, {2}, "Hello");
  232. CheckEqual(output, {3}, "世界");
  233. input = std::make_shared<Tensor>(" ");
  234. s = keep_whitespace_op->Compute(input, &output);
  235. EXPECT_TRUE(s.IsOk());
  236. EXPECT_EQ(output->Size(), 1);
  237. EXPECT_EQ(output->Rank(), 1);
  238. MS_LOG(INFO) << "Out tensor10: " << output->ToString();
  239. CheckEqual(output, {0}, " ");
  240. input = std::make_shared<Tensor>(" ");
  241. s = skip_whitespace_op->Compute(input, &output);
  242. EXPECT_TRUE(s.IsOk());
  243. EXPECT_EQ(output->Size(), 1);
  244. EXPECT_EQ(output->Rank(), 1);
  245. MS_LOG(INFO) << "Out tensor11: " << output->ToString();
  246. CheckEqual(output, {0}, "");
  247. }
  248. TEST_F(MindDataTestTokenizerOp, TestCaseFold) {
  249. MS_LOG(INFO) << "Doing TestCaseFold.";
  250. std::unique_ptr<CaseFoldOp> case_fold_op(new CaseFoldOp());
  251. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China. \n 中国\t北京");
  252. std::shared_ptr<Tensor> output;
  253. Status s = case_fold_op->Compute(input, &output);
  254. EXPECT_TRUE(s.IsOk());
  255. EXPECT_EQ(output->Size(), 1);
  256. EXPECT_EQ(output->Rank(), 0);
  257. MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  258. CheckEqual(output, {}, "welcome to china. \n 中国\t北京");
  259. }
  260. TEST_F(MindDataTestTokenizerOp, TestNormalize) {
  261. MS_LOG(INFO) << "Doing TestNormalize.";
  262. std::unique_ptr<NormalizeUTF8Op> nfc_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfc));
  263. std::unique_ptr<NormalizeUTF8Op> nfkc_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkc));
  264. std::unique_ptr<NormalizeUTF8Op> nfd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfd));
  265. std::unique_ptr<NormalizeUTF8Op> nfkd_normalize_op(new NormalizeUTF8Op(NormalizeForm::kNfkd));
  266. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("ṩ");
  267. std::shared_ptr<Tensor> output;
  268. Status s = nfc_normalize_op->Compute(input, &output);
  269. EXPECT_TRUE(s.IsOk());
  270. MS_LOG(INFO) << "NFC str:" << output->ToString();
  271. nfkc_normalize_op->Compute(input, &output);
  272. EXPECT_TRUE(s.IsOk());
  273. MS_LOG(INFO) << "NFKC str:" << output->ToString();
  274. nfd_normalize_op->Compute(input, &output);
  275. EXPECT_TRUE(s.IsOk());
  276. MS_LOG(INFO) << "NFD str:" << output->ToString();
  277. nfkd_normalize_op->Compute(input, &output);
  278. EXPECT_TRUE(s.IsOk());
  279. MS_LOG(INFO) << "NFKD str:" << output->ToString();
  280. }
  281. TEST_F(MindDataTestTokenizerOp, TestRegexReplace) {
  282. MS_LOG(INFO) << "Doing TestRegexReplace.";
  283. std::unique_ptr<RegexReplaceOp> regex_replace_op(new RegexReplaceOp("\\s+", "_", true));
  284. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China. \n 中国\t北京");
  285. std::shared_ptr<Tensor> output;
  286. Status s = regex_replace_op->Compute(input, &output);
  287. EXPECT_TRUE(s.IsOk());
  288. EXPECT_EQ(output->Size(), 1);
  289. EXPECT_EQ(output->Rank(), 0);
  290. MS_LOG(INFO) << "Out tensor1: " << output->ToString();
  291. CheckEqual(output, {}, "Welcome_to_China._中国_北京");
  292. }
  293. TEST_F(MindDataTestTokenizerOp, TestRegexTokenizer) {
  294. MS_LOG(INFO) << "Doing TestRegexTokenizerOp.";
  295. std::unique_ptr<RegexTokenizerOp> regex_tokenizer_op(new RegexTokenizerOp("\\p{Cc}|\\p{Cf}|\\s+", ""));
  296. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China. \n 中国\t北京");
  297. std::shared_ptr<Tensor> output;
  298. Status s = regex_tokenizer_op->Compute(input, &output);
  299. EXPECT_TRUE(s.IsOk());
  300. }
  301. TEST_F(MindDataTestTokenizerOp, TestBasicTokenizer) {
  302. MS_LOG(INFO) << "Doing TestBasicTokenizer.";
  303. //bool lower_case, bool keep_whitespace,
  304. // NormalizeForm normalization_form, bool preserve_unused_token
  305. std::unique_ptr<BasicTokenizerOp> basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false));
  306. std::shared_ptr<Tensor> input = std::make_shared<Tensor>("Welcome to China. 中国\t北京");
  307. std::shared_ptr<Tensor> output;
  308. Status s = basic_tokenizer->Compute(input, &output);
  309. EXPECT_TRUE(s.IsOk());
  310. }