You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vectors.cc 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "minddata/dataset/text/vectors.h"

  #include <algorithm>
  #include <cctype>
  #include <cstdlib>
  #include <fstream>
  #include <sstream>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  #include "utils/file_utils.h"
  18. namespace mindspore {
  19. namespace dataset {
  20. Status Vectors::InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines,
  21. int32_t *vector_dim) {
  22. RETURN_UNEXPECTED_IF_NULL(num_lines);
  23. RETURN_UNEXPECTED_IF_NULL(header_num_lines);
  24. RETURN_UNEXPECTED_IF_NULL(vector_dim);
  25. std::ifstream file_reader;
  26. file_reader.open(path, std::ios::in);
  27. CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(), "Vectors: invalid file, failed to open vector file: " + path);
  28. *num_lines = 0, *header_num_lines = 0, *vector_dim = -1;
  29. std::string line, row;
  30. while (std::getline(file_reader, line)) {
  31. if (*vector_dim == -1) {
  32. std::vector<std::string> vec;
  33. std::istringstream line_reader(line);
  34. while (std::getline(line_reader, row, ' ')) {
  35. vec.push_back(row);
  36. }
  37. // The number of rows and dimensions can be obtained directly from the information header.
  38. const int kInfoHeaderSize = 2;
  39. if (vec.size() == kInfoHeaderSize) {
  40. (*header_num_lines)++;
  41. } else {
  42. *vector_dim = vec.size() - 1;
  43. (*num_lines)++;
  44. }
  45. } else {
  46. (*num_lines)++;
  47. }
  48. }
  49. CHECK_FAIL_RETURN_UNEXPECTED(*num_lines > 0, "Vectors: invalid file, file is empty.");
  50. if (max_vectors > 0) {
  51. *num_lines = std::min(max_vectors, *num_lines); // Determine the true rows.
  52. }
  53. return Status::OK();
  54. }
  55. Status Vectors::Load(const std::string &path, int32_t max_vectors,
  56. std::unordered_map<std::string, std::vector<float>> *map, int *vector_dim) {
  57. RETURN_UNEXPECTED_IF_NULL(map);
  58. RETURN_UNEXPECTED_IF_NULL(vector_dim);
  59. auto realpath = FileUtils::GetRealPath(common::SafeCStr(path));
  60. CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Vectors: get real path failed, path: " + path);
  61. auto file_path = realpath.value();
  62. CHECK_FAIL_RETURN_UNEXPECTED(max_vectors >= 0,
  63. "Vectors: max_vectors must be non negative, but got: " + std::to_string(max_vectors));
  64. int num_lines = 0, header_num_lines = 0;
  65. RETURN_IF_NOT_OK(InferShape(file_path, max_vectors, &num_lines, &header_num_lines, vector_dim));
  66. std::fstream file_reader;
  67. file_reader.open(file_path, std::ios::in);
  68. CHECK_FAIL_RETURN_UNEXPECTED(file_reader.is_open(),
  69. "Vectors: invalid file, failed to open vector file: " + file_path);
  70. while (header_num_lines > 0) {
  71. file_reader.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
  72. header_num_lines--;
  73. }
  74. std::string line, token, vector_value;
  75. for (auto i = 0; i < num_lines; ++i) {
  76. std::getline(file_reader, line);
  77. std::istringstream line_reader(line);
  78. std::getline(line_reader, token, ' ');
  79. std::vector<float> vector_values;
  80. int dim = 0;
  81. while (line_reader >> vector_value) {
  82. dim++;
  83. vector_values.push_back(atof(vector_value.c_str()));
  84. }
  85. CHECK_FAIL_RETURN_UNEXPECTED(dim > 1, "Vectors: token with 1-dimensional vector.");
  86. CHECK_FAIL_RETURN_UNEXPECTED(dim == *vector_dim,
  87. "Vectors: all vectors must have the same number of dimensions, but got dim " +
  88. std::to_string(dim) + " while expecting " + std::to_string(*vector_dim));
  89. auto token_index = map->find(token);
  90. if (token_index == map->end()) {
  91. (*map)[token] = vector_values;
  92. }
  93. }
  94. return Status::OK();
  95. }
  96. Vectors::Vectors(const std::unordered_map<std::string, std::vector<float>> &map, int dim) {
  97. map_ = std::move(map);
  98. dim_ = dim;
  99. }
  100. Status Vectors::BuildFromFile(std::shared_ptr<Vectors> *vectors, const std::string &path, int32_t max_vectors) {
  101. std::unordered_map<std::string, std::vector<float>> map;
  102. int vector_dim = -1;
  103. RETURN_IF_NOT_OK(Load(path, max_vectors, &map, &vector_dim));
  104. *vectors = std::make_shared<Vectors>(std::move(map), vector_dim);
  105. return Status::OK();
  106. }
  107. std::vector<float> Vectors::Lookup(const std::string &token, const std::vector<float> &unk_init,
  108. bool lower_case_backup) {
  109. std::vector<float> init_vec(dim_, 0);
  110. if (!unk_init.empty()) {
  111. if (unk_init.size() != dim_) {
  112. MS_LOG(WARNING) << "Vectors: size of unk_init is not the same as vectors, will initialize with zero vectors.";
  113. } else {
  114. init_vec = unk_init;
  115. }
  116. }
  117. std::string lower_token = token;
  118. if (lower_case_backup) {
  119. transform(lower_token.begin(), lower_token.end(), lower_token.begin(), ::tolower);
  120. }
  121. auto str_index = map_.find(lower_token);
  122. if (str_index == map_.end()) {
  123. return init_vec;
  124. } else {
  125. return str_index->second;
  126. }
  127. }
  128. } // namespace dataset
  129. } // namespace mindspore