You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

text.cc 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <unistd.h>
  17. #include "minddata/dataset/include/text.h"
  18. #include "minddata/dataset/text/kernels/lookup_op.h"
  19. #include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
  20. #include "minddata/dataset/util/path.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. // Transform operations for text.
  24. namespace text {
  25. // FUNCTIONS TO CREATE TEXT OPERATIONS
  26. // (In alphabetical order)
  27. std::shared_ptr<LookupOperation> Lookup(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
  28. const DataType &data_type) {
  29. auto op = std::make_shared<LookupOperation>(vocab, unknown_token, data_type);
  30. return op->ValidateParams() ? op : nullptr;
  31. }
  32. std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(
  33. const std::shared_ptr<SentencePieceVocab> &vocab, SPieceTokenizerOutType out_type) {
  34. auto op = std::make_shared<SentencePieceTokenizerOperation>(vocab, out_type);
  35. return op->ValidateParams() ? op : nullptr;
  36. }
  37. std::shared_ptr<SentencePieceTokenizerOperation> SentencePieceTokenizer(const std::string &vocab_path,
  38. SPieceTokenizerOutType out_type) {
  39. auto op = std::make_shared<SentencePieceTokenizerOperation>(vocab_path, out_type);
  40. return op->ValidateParams() ? op : nullptr;
  41. }
  42. /* ####################################### Validator Functions ############################################ */
  43. /* ####################################### Derived TensorOperation classes ################################# */
  44. // (In alphabetical order)
  45. // LookupOperation
  46. LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std::string &unknown_token,
  47. const DataType &data_type)
  48. : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
  49. Status LookupOperation::ValidateParams() {
  50. if (vocab_ == nullptr) {
  51. std::string err_msg = "Lookup: vocab object type is incorrect or null.";
  52. MS_LOG(ERROR) << err_msg;
  53. RETURN_STATUS_SYNTAX_ERROR(err_msg);
  54. }
  55. default_id_ = vocab_->Lookup(unknown_token_);
  56. if (default_id_ == Vocab::kNoTokenExists) {
  57. std::string err_msg = "Lookup: " + unknown_token_ + " doesn't exist in vocab.";
  58. MS_LOG(ERROR) << err_msg;
  59. RETURN_STATUS_SYNTAX_ERROR(err_msg);
  60. }
  61. return Status::OK();
  62. }
  63. std::shared_ptr<TensorOp> LookupOperation::Build() {
  64. std::shared_ptr<LookupOp> tensor_op = std::make_shared<LookupOp>(vocab_, default_id_, data_type_);
  65. return tensor_op;
  66. }
  67. // SentencePieceTokenizerOperation
  68. SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::shared_ptr<SentencePieceVocab> &vocab,
  69. SPieceTokenizerOutType out_type)
  70. : vocab_(vocab), vocab_path_(std::string()), load_type_(SPieceTokenizerLoadType::kModel), out_type_(out_type) {}
  71. SentencePieceTokenizerOperation::SentencePieceTokenizerOperation(const std::string &vocab_path,
  72. SPieceTokenizerOutType out_type)
  73. : vocab_(nullptr), vocab_path_(vocab_path), load_type_(SPieceTokenizerLoadType::kFile), out_type_(out_type) {}
  74. Status SentencePieceTokenizerOperation::ValidateParams() {
  75. if (load_type_ == SPieceTokenizerLoadType::kModel) {
  76. if (vocab_ == nullptr) {
  77. std::string err_msg = "SentencePieceTokenizer: vocab object type is incorrect or null.";
  78. MS_LOG(ERROR) << err_msg;
  79. RETURN_STATUS_SYNTAX_ERROR(err_msg);
  80. }
  81. } else {
  82. Path vocab_file(vocab_path_);
  83. if (!vocab_file.Exists() || vocab_file.IsDirectory()) {
  84. std::string err_msg = "SentencePieceTokenizer : vocab file: [" + vocab_path_ + "] is invalid or does not exist.";
  85. MS_LOG(ERROR) << err_msg;
  86. RETURN_STATUS_SYNTAX_ERROR(err_msg);
  87. }
  88. if (access(vocab_file.toString().c_str(), R_OK) == -1) {
  89. std::string err_msg = "SentencePieceTokenizer : no access to specified dataset file: " + vocab_path_;
  90. MS_LOG(ERROR) << err_msg;
  91. RETURN_STATUS_SYNTAX_ERROR(err_msg);
  92. }
  93. }
  94. return Status::OK();
  95. }
  96. std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
  97. std::shared_ptr<SentencePieceTokenizerOp> tensor_op;
  98. if (load_type_ == SPieceTokenizerLoadType::kModel) {
  99. tensor_op = std::make_shared<SentencePieceTokenizerOp>(vocab_, load_type_, out_type_);
  100. } else {
  101. Path vocab_file(vocab_path_);
  102. std::string model_path = vocab_file.ParentPath();
  103. std::string model_filename = vocab_file.Basename();
  104. tensor_op = std::make_shared<SentencePieceTokenizerOp>(model_path, model_filename, load_type_, out_type_);
  105. }
  106. return tensor_op;
  107. }
  108. } // namespace text
  109. } // namespace dataset
  110. } // namespace mindspore