|
|
@@ -321,6 +321,16 @@ std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr< |
|
|
return op; |
|
|
return op; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/* ####################################### Validator Functions ############################################ */ |
|
|
|
|
|
bool CheckVectorPositive(const std::vector<int32_t> &size) { |
|
|
|
|
|
for (int i = 0; i < size.size(); ++i) { |
|
|
|
|
|
if (size[i] <= 0) return false; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) { return (std::fabs(a - b) < epsilon); } |
|
|
|
|
|
|
|
|
/* ####################################### Derived TensorOperation classes ################################# */ |
|
|
/* ####################################### Derived TensorOperation classes ################################# */ |
|
|
|
|
|
|
|
|
// CenterCropOperation |
|
|
// CenterCropOperation |
|
|
@@ -331,6 +341,13 @@ bool CenterCropOperation::ValidateParams() { |
|
|
MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; |
|
|
MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
// We have to limit crop size due to library restrictions, optimized to only iterate over size_ once |
|
|
|
|
|
for (int i = 0; i < size_.size(); ++i) { |
|
|
|
|
|
if (size_[i] <= 0 || size_[i] == INT_MAX) { |
|
|
|
|
|
MS_LOG(ERROR) << "Crop: invalid size, size must be greater than zero, got: " << size_[i]; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -353,14 +370,22 @@ CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32 |
|
|
|
|
|
|
|
|
bool CropOperation::ValidateParams() { |
|
|
bool CropOperation::ValidateParams() { |
|
|
// Do some input validation. |
|
|
// Do some input validation. |
|
|
if (coordinates_.empty() || coordinates_.size() > 2) { |
|
|
|
|
|
MS_LOG(ERROR) << "Crop: coordinates must be a vector of one or two values"; |
|
|
|
|
|
|
|
|
if (coordinates_.size() != 2) { |
|
|
|
|
|
MS_LOG(ERROR) << "Crop: coordinates must be a vector of two values"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
// we don't check the coordinates here because we don't have access to image dimensions |
|
|
if (size_.empty() || size_.size() > 2) { |
|
|
if (size_.empty() || size_.size() > 2) { |
|
|
MS_LOG(ERROR) << "Crop: size must be a vector of one or two values"; |
|
|
MS_LOG(ERROR) << "Crop: size must be a vector of one or two values"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
// We have to limit crop size due to library restrictions, optimized to only iterate over size_ once |
|
|
|
|
|
for (int i = 0; i < size_.size(); ++i) { |
|
|
|
|
|
if (size_[i] <= 0 || size_[i] == INT_MAX) { |
|
|
|
|
|
MS_LOG(ERROR) << "Crop: invalid size, size must be greater than zero, got: " << size_[i]; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -452,12 +477,24 @@ bool NormalizeOperation::ValidateParams() { |
|
|
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); |
|
|
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// check mean value |
|
|
|
|
|
for (int i = 0; i < mean_.size(); ++i) { |
|
|
|
|
|
if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) { |
|
|
|
|
|
MS_LOG(ERROR) << "Normalize: mean vector has incorrect value: " << mean_[i]; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
if (std_.size() != 3) { |
|
|
if (std_.size() != 3) { |
|
|
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); |
|
|
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// check std value |
|
|
|
|
|
for (int i = 0; i < std_.size(); ++i) { |
|
|
|
|
|
if (std_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(std_[i], 0.0f)) { |
|
|
|
|
|
MS_LOG(ERROR) << "Normalize: std vector has incorrect value: " << std_[i]; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -886,6 +923,9 @@ bool ResizeOperation::ValidateParams() { |
|
|
MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); |
|
|
MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
if (!CheckVectorPositive(size_)) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|