You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.h 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
  18. #include <limits>
  19. #include <memory>
  20. #include <string>
  21. #include <nlohmann/json.hpp>
  22. #include "minddata/dataset/core/tensor.h"
  23. #include "minddata/dataset/util/status.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // validator Parameter in json file
  27. inline Status ValidateParamInJson(const nlohmann::json &json_obj, const std::string &param_name,
  28. const std::string &operator_name) {
  29. if (json_obj.find(param_name) == json_obj.end()) {
  30. std::string err_msg = "Failed to find key '" + param_name + "' in " + operator_name +
  31. "' JSON file or input dict, check input content of deserialize().";
  32. RETURN_STATUS_UNEXPECTED(err_msg);
  33. }
  34. return Status::OK();
  35. }
  36. inline Status ValidateTensorShape(const std::string &op_name, bool cond, const std::string &expected_shape = "",
  37. const std::string &actual_dim = "") {
  38. if (!cond) {
  39. std::string err_msg = op_name + ": the shape of input tensor does not match the requirement of operator.";
  40. if (expected_shape != "") {
  41. err_msg += " Expecting tensor in shape of " + expected_shape + ".";
  42. }
  43. if (actual_dim != "") {
  44. err_msg += " But got tensor with dimension " + actual_dim + ".";
  45. }
  46. RETURN_STATUS_UNEXPECTED(err_msg);
  47. }
  48. return Status::OK();
  49. }
  50. inline Status ValidateLowRank(const std::string &op_name, const std::shared_ptr<Tensor> &input, dsize_t threshold = 0,
  51. const std::string &expected_shape = "") {
  52. dsize_t dim = input->shape().Size();
  53. return ValidateTensorShape(op_name, dim >= threshold, expected_shape, std::to_string(dim));
  54. }
  55. inline Status ValidateTensorType(const std::string &op_name, bool cond, const std::string &expected_type = "",
  56. const std::string &actual_type = "") {
  57. if (!cond) {
  58. std::string err_msg = op_name + ": the data type of input tensor does not match the requirement of operator.";
  59. if (expected_type != "") {
  60. err_msg += " Expecting tensor in type of " + expected_type + ".";
  61. }
  62. if (actual_type != "") {
  63. err_msg += " But got type " + actual_type + ".";
  64. }
  65. RETURN_STATUS_UNEXPECTED(err_msg);
  66. }
  67. return Status::OK();
  68. }
  69. inline Status ValidateTensorNumeric(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
  70. return ValidateTensorType(op_name, input->type().IsNumeric(), "[int, float, double]", input->type().ToString());
  71. }
  72. inline Status ValidateTensorFloat(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
  73. return ValidateTensorType(op_name, input->type().IsFloat(), "[float, double]", input->type().ToString());
  74. }
  75. template <typename T>
  76. inline Status ValidateEqual(const std::string &op_name, const std::string &param_name, T param_value,
  77. const std::string &other_name, T other_value) {
  78. if (param_value != other_value) {
  79. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be equal to '" + other_name +
  80. "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
  81. " " + std::to_string(other_value) + ".";
  82. RETURN_STATUS_UNEXPECTED(err_msg);
  83. }
  84. return Status::OK();
  85. }
  86. template <typename T>
  87. inline Status ValidateNotEqual(const std::string &op_name, const std::string &param_name, T param_value,
  88. const std::string &other_name, T other_value) {
  89. if (param_value == other_value) {
  90. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' can not be equal to '" + other_name +
  91. "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
  92. " " + std::to_string(other_value) + ".";
  93. RETURN_STATUS_UNEXPECTED(err_msg);
  94. }
  95. return Status::OK();
  96. }
  97. template <typename T>
  98. inline Status ValidateGreaterThan(const std::string &op_name, const std::string &param_name, T param_value,
  99. const std::string &other_name, T other_value) {
  100. if (param_value <= other_value) {
  101. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be greater than '" + other_name +
  102. "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
  103. " " + std::to_string(other_value) + ".";
  104. RETURN_STATUS_UNEXPECTED(err_msg);
  105. }
  106. return Status::OK();
  107. }
  108. template <typename T>
  109. inline Status ValidateLessThan(const std::string &op_name, const std::string &param_name, T param_value,
  110. const std::string &other_name, T other_value) {
  111. if (param_value >= other_value) {
  112. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be less than '" + other_name +
  113. "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
  114. " " + std::to_string(other_value) + ".";
  115. RETURN_STATUS_UNEXPECTED(err_msg);
  116. }
  117. return Status::OK();
  118. }
  119. template <typename T>
  120. inline Status ValidateNoGreaterThan(const std::string &op_name, const std::string &param_name, T param_value,
  121. const std::string &other_name, T other_value) {
  122. if (param_value > other_value) {
  123. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no greater than '" +
  124. other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " +
  125. other_name + " " + std::to_string(other_value) + ".";
  126. RETURN_STATUS_UNEXPECTED(err_msg);
  127. }
  128. return Status::OK();
  129. }
  130. template <typename T>
  131. inline Status ValidateNoLessThan(const std::string &op_name, const std::string &param_name, T param_value,
  132. const std::string &other_name, T other_value) {
  133. if (param_value < other_value) {
  134. std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no less than '" + other_name +
  135. "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
  136. " " + std::to_string(other_value) + ".";
  137. RETURN_STATUS_UNEXPECTED(err_msg);
  138. }
  139. return Status::OK();
  140. }
  141. template <typename T>
  142. inline Status ValidatePositive(const std::string &op_name, const std::string &param_name, T param_value) {
  143. if (param_value <= 0) {
  144. std::string err_msg = op_name + ": invalid parameter, '" + param_name +
  145. "' should be positive, but got: " + std::to_string(param_value) + ".";
  146. RETURN_STATUS_UNEXPECTED(err_msg);
  147. }
  148. return Status::OK();
  149. }
  150. template <typename T>
  151. inline Status ValidateNegative(const std::string &op_name, const std::string &param_name, T param_value) {
  152. if (param_value >= 0) {
  153. std::string err_msg = op_name + ": invalid parameter, '" + param_name +
  154. "' should be negative, but got: " + std::to_string(param_value) + ".";
  155. RETURN_STATUS_UNEXPECTED(err_msg);
  156. }
  157. return Status::OK();
  158. }
  159. template <typename T>
  160. inline Status ValidateNonPositive(const std::string &op_name, const std::string &param_name, T param_value) {
  161. if (param_value > 0) {
  162. std::string err_msg = op_name + ": invalid parameter, '" + param_name +
  163. "' should be non positive, but got: " + std::to_string(param_value) + ".";
  164. RETURN_STATUS_UNEXPECTED(err_msg);
  165. }
  166. return Status::OK();
  167. }
  168. template <typename T>
  169. inline Status ValidateNonNegative(const std::string &op_name, const std::string &param_name, T param_value) {
  170. if (param_value < 0) {
  171. std::string err_msg = op_name + ": invalid parameter, '" + param_name +
  172. "' should be non negative, but got: " + std::to_string(param_value) + ".";
  173. RETURN_STATUS_UNEXPECTED(err_msg);
  174. }
  175. return Status::OK();
  176. }
  177. } // namespace dataset
  178. } // namespace mindspore
  179. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_