@@ -48,12 +48,10 @@ void CutMixBatchOp::GetCropBox(int height, int width, float lam, int *x, int *y,
  *crop_height = std::clamp(y2 - *y, 1, height - 1);
}

Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
Status CutMixBatchOp::ValidateCutMixBatch(const TensorRow &input) {
  if (input.size() < 2) {
    RETURN_STATUS_UNEXPECTED("CutMixBatch: both image and label columns are required.");
  }

  std::vector<std::shared_ptr<Tensor>> images;
  std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
  std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();

@@ -80,14 +78,98 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
    RETURN_STATUS_UNEXPECTED("CutMixBatch: image doesn't match the NHWC format.");
  }

  return Status::OK();
}
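
// Builds the mixed image for one sample: the partner image at rand_indx_i is cropped with a box
// sized from lam and pasted onto *image_i at the same location; *label_lam is recomputed from the
// actual crop area so the label mix matches the pixel mix.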
Status CutMixBatchOp::ComputeImage(const TensorRow &input, const int64_t rand_indx_i, const float lam, float *label_lam,
                                   std::shared_ptr<Tensor> *image_i) {
  std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
  int x, y, crop_width, crop_height;
  // Get a random image
  TensorShape remaining({-1});
  uchar *start_addr_of_index = nullptr;
  std::shared_ptr<Tensor> rand_image;

  RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx_i, 0, 0, 0}, &start_addr_of_index, &remaining));
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
                                            input.at(0)->type(), start_addr_of_index, &rand_image));

  // Compute image
  if (image_batch_format_ == ImageBatchFormat::kNHWC) {
    // NHWC Format
    GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width,
               &crop_height);
    std::shared_ptr<Tensor> cropped;
    RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
    RETURN_IF_NOT_OK(MaskWithTensor(cropped, image_i, x, y, crop_width, crop_height, ImageFormat::HWC));
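    // The label mixing ratio is the fraction of image_i left untouched: one minus the pasted crop's
    // share of the total image area.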
    *label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
  } else {
    // NCHW Format
    GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
               &crop_height);
    std::vector<std::shared_ptr<Tensor>> channels;  // A vector holding channels of the CHW image
    std::vector<std::shared_ptr<Tensor>> cropped_channels;  // A vector holding the channels of the cropped CHW
    RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels));
    for (auto channel : channels) {
      // Call crop for each single channel
      std::shared_ptr<Tensor> cropped_channel;
      RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height));
      cropped_channels.push_back(cropped_channel);
    }
    std::shared_ptr<Tensor> cropped;
    // Merge channels to a single tensor
    RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped));

    RETURN_IF_NOT_OK(MaskWithTensor(cropped, image_i, x, y, crop_width, crop_height, ImageFormat::CHW));
    *label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3]));
  }

  return Status::OK();
}
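
// Mixes the label of sample index_i with the label of its partner rand_indx_i:
// out_label = label_lam * label[index_i] + (1 - label_lam) * label[rand_indx_i],
// for both 2-D (batch, class) and 3-D (batch, row, class) label layouts.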
Status CutMixBatchOp::ComputeLabel(const TensorRow &input, const int64_t rand_indx_i, const int64_t index_i,
                                   const int64_t row_labels, const int64_t num_classes,
                                   const std::size_t label_shape_size, const float label_lam,
                                   std::shared_ptr<Tensor> *out_labels) {
  // Compute labels
  for (int64_t j = 0; j < row_labels; j++) {
    for (int64_t k = 0; k < num_classes; k++) {
      std::vector<int64_t> first_index = label_shape_size == 3 ? std::vector{index_i, j, k} : std::vector{index_i, k};
      std::vector<int64_t> second_index =
        label_shape_size == 3 ? std::vector{rand_indx_i, j, k} : std::vector{rand_indx_i, k};
      if (input.at(1)->type().IsSignedInt()) {
        int64_t first_value, second_value;
        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
        RETURN_IF_NOT_OK(
          (*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
      } else {
        uint64_t first_value, second_value;
        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
        RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
        RETURN_IF_NOT_OK(
          (*out_labels)->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
      }
    }
  }

  return Status::OK();
}
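
// Compute now only orchestrates the batch: it validates the input, unpacks the image batch,
// pairs each sample with a random partner, and delegates the per-sample work to ComputeImage
// and ComputeLabel.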
Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
  IO_CHECK_VECTOR(input, output);
  RETURN_IF_NOT_OK(ValidateCutMixBatch(input));
  std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
  std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();

  // Move images into a vector of Tensors
  std::vector<std::shared_ptr<Tensor>> images;
  RETURN_IF_NOT_OK(BatchTensorToTensorVector(input.at(0), &images));

  // Pair each image in the batch with a randomly chosen partner image
  std::vector<int64_t> rand_indx;
  for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i);
  std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);
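  // After the shuffle, image i will be cut-mixed with image rand_indx[i].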

  std::gamma_distribution<float> gamma_distribution(alpha_, 1);
  std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0);
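  // The two gamma draws x1 and x2 are combined below as lam = x1 / (x1 + x2), i.e. lam is
  // effectively sampled from Beta(alpha_, alpha_); the uniform draw decides per sample whether
  // CutMix is applied at all (random_number < prob_).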

@@ -107,69 +189,12 @@ Status CutMixBatchOp::Compute(const TensorRow &input, TensorRow *output) {
    float lam = x1 / (x1 + x2);
    double random_number = uniform_distribution(rnd_);
    if (random_number < prob_) {
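      // CutMix is applied to this sample only with probability prob_; otherwise image i and its
      // label pass through unchanged.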
      int x, y, crop_width, crop_height;
      float label_lam;  // lambda used for labels

      // Get a random image
      TensorShape remaining({-1});
      uchar *start_addr_of_index = nullptr;
      std::shared_ptr<Tensor> rand_image;
      RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining));
      RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
                                                input.at(0)->type(), start_addr_of_index, &rand_image));

      // Compute image
      if (image_batch_format_ == ImageBatchFormat::kNHWC) {
        // NHWC Format
        GetCropBox(static_cast<int32_t>(image_shape[1]), static_cast<int32_t>(image_shape[2]), lam, &x, &y, &crop_width,
                   &crop_height);
        std::shared_ptr<Tensor> cropped;
        RETURN_IF_NOT_OK(Crop(rand_image, &cropped, x, y, crop_width, crop_height));
        RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::HWC));
        label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[1] * image_shape[2]));
      } else {
        // NCHW Format
        GetCropBox(static_cast<int32_t>(image_shape[2]), static_cast<int32_t>(image_shape[3]), lam, &x, &y, &crop_width,
                   &crop_height);
        std::vector<std::shared_ptr<Tensor>> channels;  // A vector holding channels of the CHW image
        std::vector<std::shared_ptr<Tensor>> cropped_channels;  // A vector holding the channels of the cropped CHW
        RETURN_IF_NOT_OK(BatchTensorToTensorVector(rand_image, &channels));
        for (auto channel : channels) {
          // Call crop for each single channel
          std::shared_ptr<Tensor> cropped_channel;
          RETURN_IF_NOT_OK(Crop(channel, &cropped_channel, x, y, crop_width, crop_height));
          cropped_channels.push_back(cropped_channel);
        }
        std::shared_ptr<Tensor> cropped;
        // Merge channels to a single tensor
        RETURN_IF_NOT_OK(TensorVectorToBatchTensor(cropped_channels, &cropped));

        RETURN_IF_NOT_OK(MaskWithTensor(cropped, &images[i], x, y, crop_width, crop_height, ImageFormat::CHW));
        label_lam = 1 - (crop_width * crop_height / static_cast<float>(image_shape[2] * image_shape[3]));
      }

      RETURN_IF_NOT_OK(ComputeImage(input, rand_indx[i], lam, &label_lam, &images[i]));
      // Compute labels

      for (int64_t j = 0; j < row_labels; j++) {
        for (int64_t k = 0; k < num_classes; k++) {
          std::vector<int64_t> first_index = label_shape.size() == 3 ? std::vector{i, j, k} : std::vector{i, k};
          std::vector<int64_t> second_index =
            label_shape.size() == 3 ? std::vector{rand_indx[i], j, k} : std::vector{rand_indx[i], k};
          if (input.at(1)->type().IsSignedInt()) {
            int64_t first_value, second_value;
            RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
            RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
            RETURN_IF_NOT_OK(
              out_labels->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
          } else {
            uint64_t first_value, second_value;
            RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, first_index));
            RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, second_index));
            RETURN_IF_NOT_OK(
              out_labels->SetItemAt(first_index, label_lam * first_value + (1 - label_lam) * second_value));
          }
        }
      }
      RETURN_IF_NOT_OK(
        ComputeLabel(input, rand_indx[i], i, row_labels, num_classes, label_shape.size(), label_lam, &out_labels));
    }
  }
  std::shared_ptr<Tensor> out_images;