You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vectors.h 3.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_
  18. #include <algorithm>
  19. #include <fstream>
  20. #include <limits>
  21. #include <memory>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <utility>
  25. #include <vector>
  26. #include "minddata/dataset/core/tensor.h"
  27. #include "minddata/dataset/include/dataset/iterator.h"
  28. namespace mindspore {
  29. namespace dataset {
  30. /// \brief Pre-train word vectors.
  31. class Vectors {
  32. public:
  33. /// Constructor.
  34. Vectors() = default;
  35. /// Constructor.
  36. /// \param[in] map A map between string and vector.
  37. /// \param[in] dim Dimension of the vectors.
  38. Vectors(const std::unordered_map<std::string, std::vector<float>> &map, int dim);
  39. /// Destructor.
  40. virtual ~Vectors() = default;
  41. /// \brief Build Vectors from reading a pre-train vector file.
  42. /// \param[out] vectors Vectors object which contains the pre-train vectors.
  43. /// \param[in] path Path to the pre-trained word vector file.
  44. /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded (default=0, no limit).
  45. static Status BuildFromFile(std::shared_ptr<Vectors> *vectors, const std::string &path, int32_t max_vectors = 0);
  46. /// \brief Look up embedding vectors of token.
  47. /// \param[in] token A token to be looked up.
  48. /// \param[in] unk_init In case of the token is out-of-vectors (OOV), the result will be initialized with `unk_init`.
  49. /// (default={}, means to initialize with zero vectors).
  50. /// \param[in] lower_case_backup Whether to look up the token in the lower case (Default = false).
  51. /// \return The vector of the input token.
  52. virtual std::vector<float> Lookup(const std::string &token, const std::vector<float> &unk_init = {},
  53. bool lower_case_backup = false);
  54. /// \brief Getter of dimension.
  55. const int &Dim() const { return dim_; }
  56. protected:
  57. /// \brief Infer the shape of the pre-trained word vector file.
  58. /// \param[in] path Path to the pre-trained word vector file.
  59. /// \param[in] max_vectors Maximum number of pre-trained word vectors to be read.
  60. /// \param[out] num_lines The number of lines of the file.
  61. /// \param[out] header_num_lines The number of lines of file header.
  62. /// \param[out] vector_dim The dimension of the vectors in the file.
  63. static Status InferShape(const std::string &path, int32_t max_vectors, int32_t *num_lines, int32_t *header_num_lines,
  64. int32_t *vector_dim);
  65. /// \brief Load map from reading a pre-train vector file.
  66. /// \param[in] path Path to the pre-trained word vector file.
  67. /// \param[in] max_vectors This can be used to limit the number of pre-trained vectors loaded, must be non negative.
  68. /// \param[out] map The map between words and vectors.
  69. /// \param[out] vector_dim The dimension of the vectors in the file.
  70. static Status Load(const std::string &path, int32_t max_vectors,
  71. std::unordered_map<std::string, std::vector<float>> *map, int *vector_dim);
  72. int dim_;
  73. std::unordered_map<std::string, std::vector<float>> map_;
  74. };
  75. } // namespace dataset
  76. } // namespace mindspore
  77. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_TEXT_VECTORS_H_