You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_schema.h 9.5 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATA_SCHEMA_H_
  17. #define DATASET_ENGINE_DATA_SCHEMA_H_
  18. #include <iostream>
  19. #include <map>
  20. #include <memory>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <vector>
  24. #include <nlohmann/json.hpp>
  25. #include "dataset/core/constants.h"
  26. #include "dataset/core/data_type.h"
  27. #include "dataset/core/tensor_shape.h"
  28. #include "dataset/util/status.h"
  29. namespace mindspore {
  30. namespace dataset {
  31. // A simple class to provide meta info about a column.
  32. class ColDescriptor {
  33. public:
  34. // Constructor 1: Simple constructor that leaves things uninitialized.
  35. ColDescriptor();
  36. // Constructor 2: Main constructor
  37. // @param col_name - The name of the column
  38. // @param col_type - The DE Datatype of the column
  39. // @param tensor_impl - The (initial) type of tensor implementation for the column
  40. // @param rank - The number of dimension of the data
  41. // @param in_shape - option argument for input shape
  42. ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank,
  43. const TensorShape *in_shape = nullptr);
  44. // Explicit copy constructor is required
  45. // @param in_cd - the source ColDescriptor
  46. ColDescriptor(const ColDescriptor &in_cd);
  47. // Assignment overload
  48. // @param in_cd - the source ColDescriptor
  49. ColDescriptor &operator=(const ColDescriptor &in_cd);
  50. // Destructor
  51. ~ColDescriptor();
  52. // A print method typically used for debugging
  53. // @param out - The output stream to write output to
  54. void Print(std::ostream &out) const;
  55. // Given a number of elements, this function will compute what the actual Tensor shape would be.
  56. // If there is no starting TensorShape in this column, or if there is a shape but it contains
  57. // an unknown dimension, then the output shape returned shall resolve dimensions as needed.
  58. // @param num_elements - The number of elements in the data for a Tensor
  59. // @param out_shape - The materialized output Tensor shape
  60. // @return Status - The error code return
  61. Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const;
  62. // << Stream output operator overload
  63. // @notes This allows you to write the debug print info using stream operators
  64. // @param out - reference to the output stream being overloaded
  65. // @param cd - reference to the ColDescriptor to display
  66. // @return - the output stream must be returned
  67. friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) {
  68. cd.Print(out);
  69. return out;
  70. }
  71. // getter function
  72. // @return The column's DataType
  73. DataType type() const { return type_; }
  74. // getter function
  75. // @return The column's rank
  76. int32_t rank() const { return rank_; }
  77. // getter function
  78. // @return The column's name
  79. std::string name() const { return col_name_; }
  80. // getter function
  81. // @return The column's shape
  82. TensorShape shape() const;
  83. // getter function
  84. // @return TF if the column has an assigned fixed shape.
  85. bool hasShape() const { return tensor_shape_ != nullptr; }
  86. // getter function
  87. // @return The column's tensor implementation type
  88. TensorImpl tensorImpl() const { return tensor_impl_; }
  89. private:
  90. DataType type_; // The columns type
  91. int32_t rank_; // The rank for this column (number of dimensions)
  92. TensorImpl tensor_impl_; // The initial flavour of the tensor for this column.
  93. std::unique_ptr<TensorShape> tensor_shape_; // The fixed shape (if given by user)
  94. std::string col_name_; // The name of the column
  95. };
  96. // A list of the columns.
  97. class DataSchema {
  98. public:
  99. // Constructor
  100. DataSchema();
  101. // Destructor
  102. ~DataSchema();
  103. // Populates the schema with a dataset type from a json file. It does not populate any of the
  104. // column info. To populate everything, use loadSchema() afterwards.
  105. // @param schema_file_path - Absolute path to the schema file to use for getting dataset type info.
  106. Status LoadDatasetType(const std::string &schema_file_path);
  107. // Parses a schema json file and populates the columns and meta info.
  108. // @param schema_file_path - the schema file that has the column's info to load
  109. // @param columns_to_load - list of strings for columns to load. if empty, assumes all columns.
  110. // @return Status - The error code return
  111. Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load);
  112. // Parses a schema JSON string and populates the columns and meta info.
  113. // @param schema_json_string - the schema file that has the column's info to load
  114. // @param columns_to_load - list of strings for columns to load. if empty, assumes all columns.
  115. // @return Status - The error code return
  116. Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load);
  117. // A print method typically used for debugging
  118. // @param out - The output stream to write output to
  119. void Print(std::ostream &out) const;
  120. // << Stream output operator overload
  121. // @notes This allows you to write the debug print info using stream operators
  122. // @param out - reference to the output stream being overloaded
  123. // @param ds - reference to the DataSchema to display
  124. // @return - the output stream must be returned
  125. friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) {
  126. ds.Print(out);
  127. return out;
  128. }
  129. // Adds a column descriptor to the schema
  130. // @param cd - The ColDescriptor to add
  131. // @return Status - The error code return
  132. Status AddColumn(const ColDescriptor &cd);
  133. // Setter
  134. // @param in_type - The Dataset type to set into the schema
  135. void set_dataset_type(DatasetType in_type) { dataset_type_ = in_type; }
  136. // getter
  137. // @return The dataset type of the schema
  138. DatasetType dataset_type() const { return dataset_type_; }
  139. // getter
  140. // @return The reference to a ColDescriptor to get (const version)
  141. const ColDescriptor &column(int32_t idx) const;
  142. // getter
  143. // @return The number of columns in the schema
  144. int32_t NumColumns() const { return col_descs_.size(); }
  145. bool Empty() const { return NumColumns() == 0; }
  146. std::string dir_structure() const { return dir_structure_; }
  147. std::string dataset_type_str() const { return dataset_type_str_; }
  148. int64_t num_rows() const { return num_rows_; }
  149. static const char DEFAULT_DATA_SCHEMA_FILENAME[];
  150. // Loops through all columns in the schema and returns a map with the column
  151. // name to column index number.
  152. // @param out_column_name_map - The output map of columns names to column index
  153. // @return Status - The error code return
  154. Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
  155. private:
  156. // Internal helper function. Parses the json schema file in any order and produces a schema that
  157. // does not follow any particular order (json standard does not enforce any ordering protocol).
  158. // This one produces a schema that contains all of the columns from the schema file.
  159. // @param column_tree - The nlohmann tree from the json file to parse
  160. // @return Status - The error code return
  161. Status AnyOrderLoad(nlohmann::json column_tree);
  162. // Internal helper function. For each input column name, perform a lookup to the json document to
  163. // find the matching column. When the match is found, process that column to build the column
  164. // descriptor and add to the schema in the order in which the input column names are given.
  165. // @param column_tree - The nlohmann tree from the json file to parse
  166. // @param columns_to_load - list of strings for the columns to add to the schema
  167. // @return Status - The error code return
  168. Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load);
  169. // Internal helper function. Given the json tree for a given column, load it into our schema.
  170. // @param columnTree - The nlohmann child tree for a given column to load.
  171. // @param col_name - The string name of the column for that subtree.
  172. // @return Status - The error code return
  173. Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name);
  174. // Internal helper function. Performs sanity checks on the json file setup.
  175. // @param js - The nlohmann tree for the schema file
  176. // @return Status - The error code return
  177. Status PreLoadExceptionCheck(const nlohmann::json &js);
  178. DatasetType GetDatasetTYpeFromString(const std::string &type) const;
  179. std::vector<ColDescriptor> col_descs_; // Vector of column descriptors
  180. std::string dataset_type_str_; // A string that represents the type of dataset
  181. DatasetType dataset_type_; // The numeric form of the dataset type from enum
  182. std::string dir_structure_; // Implicit or flatten
  183. int64_t num_rows_;
  184. };
  185. } // namespace dataset
  186. } // namespace mindspore
  187. #endif // DATASET_ENGINE_DATA_SCHEMA_H_