/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_ #include #include #include #include #include "minddata/dataset/core/tensor.h" #include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { // validator Parameter in json file inline Status ValidateParamInJson(const nlohmann::json &json_obj, const std::string ¶m_name, const std::string &operator_name) { if (json_obj.find(param_name) == json_obj.end()) { std::string err_msg = "Failed to find key '" + param_name + "' in " + operator_name + "' JSON file or input dict, check input content of deserialize()."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } inline Status ValidateTensorShape(const std::string &op_name, bool cond, const std::string &expected_shape = "", const std::string &actual_dim = "") { if (!cond) { std::string err_msg = op_name + ": the shape of input tensor does not match the requirement of operator."; if (expected_shape != "") { err_msg += " Expecting tensor in shape of " + expected_shape + "."; } if (actual_dim != "") { err_msg += " But got tensor with dimension " + actual_dim + "."; } RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } inline Status ValidateLowRank(const std::string &op_name, const std::shared_ptr &input, dsize_t threshold = 0, const std::string &expected_shape = "") { dsize_t dim = input->shape().Size(); return ValidateTensorShape(op_name, dim >= threshold, expected_shape, std::to_string(dim)); } inline Status ValidateTensorType(const std::string &op_name, bool cond, const std::string &expected_type = "", const std::string &actual_type = "") { if (!cond) { std::string err_msg = op_name + ": the data type of input tensor does not match the requirement of operator."; if (expected_type != "") { err_msg += " Expecting tensor in type of " + expected_type + "."; } if (actual_type != "") { err_msg += " But got type " + actual_type + "."; } RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } inline Status ValidateTensorNumeric(const std::string &op_name, const std::shared_ptr &input) { return ValidateTensorType(op_name, input->type().IsNumeric(), "[int, float, double]", input->type().ToString()); } inline Status ValidateTensorFloat(const std::string &op_name, const std::shared_ptr &input) { return ValidateTensorType(op_name, input->type().IsFloat(), "[float, double]", input->type().ToString()); } template inline Status ValidateEqual(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value != other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be equal to '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNotEqual(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value == other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' can not be equal to '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateGreaterThan(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value <= other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be greater than '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateLessThan(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value >= other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be less than '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNoGreaterThan(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value > other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no greater than '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNoLessThan(const std::string &op_name, const std::string ¶m_name, T param_value, const std::string &other_name, T other_value) { if (param_value < other_value) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no less than '" + other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name + " " + std::to_string(other_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidatePositive(const std::string &op_name, const std::string ¶m_name, T param_value) { if (param_value <= 0) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be positive, but got: " + std::to_string(param_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNegative(const std::string &op_name, const std::string ¶m_name, T param_value) { if (param_value >= 0) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be negative, but got: " + std::to_string(param_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNonPositive(const std::string &op_name, const std::string ¶m_name, T param_value) { if (param_value > 0) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be non positive, but got: " + std::to_string(param_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } template inline Status ValidateNonNegative(const std::string &op_name, const std::string ¶m_name, T param_value) { if (param_value < 0) { std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be non negative, but got: " + std::to_string(param_value) + "."; RETURN_STATUS_UNEXPECTED(err_msg); } return Status::OK(); } } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_