|
|
|
@@ -32,7 +32,7 @@ CutMixBatchOp::CutMixBatchOp(ImageBatchFormat image_batch_format, float alpha, f |
|
|
|
} |
|
|
|
|
|
|
|
void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y, int *crop_width, int *crop_height) { |
|
|
|
float cut_ratio = 1 - lam; |
|
|
|
const float cut_ratio = 1 - lam; |
|
|
|
int cut_w = static_cast<int>(width * cut_ratio); |
|
|
|
int cut_h = static_cast<int>(height * cut_ratio); |
|
|
|
std::uniform_int_distribution<int> width_uniform_distribution(0, width); |
|
|
|
@@ -116,7 +116,6 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) { |
|
|
|
RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height)); |
|
|
|
RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC)); |
|
|
|
label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2])); |
|
|
|
|
|
|
|
} else { |
|
|
|
// NCHW Format |
|
|
|
GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width, |
|
|
|
|