| @@ -69,7 +69,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||
| } | |||
| *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, | |||
| @@ -308,30 +308,30 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) { | |||
| } | |||
| while (object != nullptr) { | |||
| std::string label_name; | |||
| uint32_t xmin = 0, ymin = 0, xmax = 0, ymax = 0, truncated = 0, difficult = 0; | |||
| float xmin = 0.0, ymin = 0.0, xmax = 0.0, ymax = 0.0, truncated = 0.0, difficult = 0.0; | |||
| XMLElement *name_node = object->FirstChildElement("name"); | |||
| if (name_node != nullptr && name_node->GetText() != 0) label_name = name_node->GetText(); | |||
| XMLElement *truncated_node = object->FirstChildElement("truncated"); | |||
| if (truncated_node != nullptr) truncated = truncated_node->UnsignedText(); | |||
| if (truncated_node != nullptr) truncated = truncated_node->FloatText(); | |||
| XMLElement *difficult_node = object->FirstChildElement("difficult"); | |||
| if (difficult_node != nullptr) difficult = difficult_node->UnsignedText(); | |||
| if (difficult_node != nullptr) difficult = difficult_node->FloatText(); | |||
| XMLElement *bbox_node = object->FirstChildElement("bndbox"); | |||
| if (bbox_node != nullptr) { | |||
| XMLElement *xmin_node = bbox_node->FirstChildElement("xmin"); | |||
| if (xmin_node != nullptr) xmin = xmin_node->UnsignedText(); | |||
| if (xmin_node != nullptr) xmin = xmin_node->FloatText(); | |||
| XMLElement *ymin_node = bbox_node->FirstChildElement("ymin"); | |||
| if (ymin_node != nullptr) ymin = ymin_node->UnsignedText(); | |||
| if (ymin_node != nullptr) ymin = ymin_node->FloatText(); | |||
| XMLElement *xmax_node = bbox_node->FirstChildElement("xmax"); | |||
| if (xmax_node != nullptr) xmax = xmax_node->UnsignedText(); | |||
| if (xmax_node != nullptr) xmax = xmax_node->FloatText(); | |||
| XMLElement *ymax_node = bbox_node->FirstChildElement("ymax"); | |||
| if (ymax_node != nullptr) ymax = ymax_node->UnsignedText(); | |||
| if (ymax_node != nullptr) ymax = ymax_node->FloatText(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("bndbox dismatch in " + path); | |||
| } | |||
| if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 && | |||
| ymin > 0 && xmax > xmin && ymax > ymin) { | |||
| std::vector<uint32_t> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; | |||
| std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; | |||
| bbox.emplace_back(std::make_pair(label_name, bbox_list)); | |||
| label_index_[label_name] = 0; | |||
| } | |||
| @@ -376,17 +376,17 @@ Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &co | |||
| Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, | |||
| std::shared_ptr<Tensor> *tensor) { | |||
| Bbox bbox_info = label_map_[path]; | |||
| std::vector<uint32_t> bbox_row; | |||
| std::vector<float> bbox_row; | |||
| dsize_t bbox_column_num = 0, bbox_num = 0; | |||
| for (auto box : bbox_info) { | |||
| if (label_index_.find(box.first) != label_index_.end()) { | |||
| std::vector<uint32_t> bbox; | |||
| std::vector<float> bbox; | |||
| bbox.insert(bbox.end(), box.second.begin(), box.second.end()); | |||
| if (class_index_.find(box.first) != class_index_.end()) { | |||
| bbox.emplace_back(class_index_[box.first]); | |||
| bbox.push_back(static_cast<float>(class_index_[box.first])); | |||
| } else { | |||
| bbox.emplace_back(label_index_[box.first]); | |||
| bbox.push_back(static_cast<float>(label_index_[box.first])); | |||
| } | |||
| bbox.insert(bbox.end(), box.second.begin(), box.second.end()); | |||
| bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); | |||
| if (bbox_column_num == 0) { | |||
| bbox_column_num = static_cast<dsize_t>(bbox.size()); | |||
| @@ -40,7 +40,7 @@ namespace dataset { | |||
| template <typename T> | |||
| class Queue; | |||
| using Bbox = std::vector<std::pair<std::string, std::vector<uint32_t>>>; | |||
| using Bbox = std::vector<std::pair<std::string, std::vector<float>>>; | |||
| class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| public: | |||
| @@ -1,292 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the bounding box augment op in DE | |||
| """ | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| import numpy as np | |||
| import mindspore.log as logger | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| if bbox.size == 7: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_bounding_box_augment_with_rotation_op(plot_vis=False): | |||
| """ | |||
| Test BoundingBoxAugment op (passing rotation op as transform) | |||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||
| """ | |||
| logger.info("test_bounding_box_augment_with_rotation_op") | |||
| original_seed = config_get_set_seed(0) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # Ratio is set to 1 to apply rotation on all bounding boxes. | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_rotation_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_bounding_box_augment_with_crop_op(plot_vis=False): | |||
| """ | |||
| Test BoundingBoxAugment op (passing crop op as transform) | |||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||
| """ | |||
| logger.info("test_bounding_box_augment_with_crop_op") | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # Ratio is set to 1 to apply rotation on all bounding boxes. | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(90), 1) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_crop_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_bounding_box_augment_valid_ratio_c(plot_vis=False): | |||
| """ | |||
| Test BoundingBoxAugment op (testing with valid ratio, less than 1. | |||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||
| """ | |||
| logger.info("test_bounding_box_augment_valid_ratio_c") | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| filename = "bounding_box_augment_valid_ratio_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_bounding_box_augment_valid_edge_c(plot_vis=False): | |||
| """ | |||
| Test BoundingBoxAugment op (testing with valid edge case, box covering full image). | |||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||
| """ | |||
| logger.info("test_bounding_box_augment_valid_edge_c") | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| # Add column for "annotation" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_valid_edge_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_bounding_box_augment_invalid_ratio_c(): | |||
| """ | |||
| Test BoundingBoxAugment op with invalid input ratio | |||
| """ | |||
| logger.info("test_bounding_box_augment_invalid_ratio_c") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| try: | |||
| # ratio range is from 0 - 1 | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input is not" in str(error) | |||
| def test_bounding_box_augment_invalid_bounds_c(): | |||
| """ | |||
| Test BoundingBoxAugment op with invalid bboxes. | |||
| """ | |||
| logger.info("test_bounding_box_augment_invalid_bounds_c") | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), | |||
| 1) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| # set to false to not show plots | |||
| test_bounding_box_augment_with_rotation_op(plot_vis=False) | |||
| test_bounding_box_augment_with_crop_op(plot_vis=False) | |||
| test_bounding_box_augment_valid_ratio_c(plot_vis=False) | |||
| test_bounding_box_augment_valid_edge_c(plot_vis=False) | |||
| test_bounding_box_augment_invalid_ratio_c() | |||
| test_bounding_box_augment_invalid_bounds_c() | |||
| @@ -37,7 +37,7 @@ def test_voc_detection(): | |||
| for item in data1.create_dict_iterator(): | |||
| assert item["image"].shape[0] == IMAGE_SHAPE[num] | |||
| for bbox in item["annotation"]: | |||
| count[bbox[0]] += 1 | |||
| count[int(bbox[6])] += 1 | |||
| num += 1 | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| @@ -55,8 +55,8 @@ def test_voc_class_index(): | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| for bbox in item["annotation"]: | |||
| assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5) | |||
| count[bbox[0]] += 1 | |||
| assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 5) | |||
| count[int(bbox[6])] += 1 | |||
| num += 1 | |||
| assert num == 6 | |||
| assert count == [3, 2, 0, 0, 0, 3] | |||
| @@ -73,8 +73,9 @@ def test_voc_get_class_indexing(): | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| for bbox in item["annotation"]: | |||
| assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5) | |||
| count[bbox[0]] += 1 | |||
| assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 2 or int(bbox[6]) == 3 | |||
| or int(bbox[6]) == 4 or int(bbox[6]) == 5) | |||
| count[int(bbox[6])] += 1 | |||
| num += 1 | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| @@ -1,220 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomCropAndResizeWithBBox op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| # updated VOC dataset with correct annotations | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| if bbox.size == 7: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_random_resized_crop_with_bbox_op_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, | |||
| tests with MD5 check, expected to pass | |||
| """ | |||
| logger.info("test_random_resized_crop_with_bbox_op_c") | |||
| original_seed = config_get_set_seed(23415) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| filename = "random_resized_crop_with_bbox_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, | |||
| tests on dynamically generated edge case, expected to pass | |||
| """ | |||
| logger.info("test_random_resized_crop_with_bbox_op_edge_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_resized_crop_with_bbox_op_invalid_c(): | |||
| """ | |||
| Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_resized_crop_with_bbox_op_invalid_c") | |||
| # Load dataset, only Augmented Dataset as test will raise ValueError | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| try: | |||
| # If input range of scale is not in the order of (min, max), ValueError will be raised. | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| break | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input range is not valid" in str(err) | |||
| def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||
| """ | |||
| Tests RandomResizedCropWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_resized_crop_with_bbox_op_invalid2_c") | |||
| # Load dataset # only loading the to AugDataset as test will fail on this | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| try: | |||
| # If input range of ratio is not in the order of (min, max), ValueError will be raised. | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| break | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input range is not valid" in str(err) | |||
| def test_random_resized_crop_with_bbox_op_bad_c(): | |||
| """ | |||
| Test RandomCropWithBBox op with invalid bounding boxes, expected to catch multiple errors. | |||
| """ | |||
| logger.info("test_random_resized_crop_with_bbox_op_bad_c") | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| test_random_resized_crop_with_bbox_op_c(plot_vis=True) | |||
| test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) | |||
| test_random_resized_crop_with_bbox_op_invalid_c() | |||
| test_random_resized_crop_with_bbox_op_invalid2_c() | |||
| test_random_resized_crop_with_bbox_op_bad_c() | |||
| @@ -1,265 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomCropWithBBox op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| import mindspore.dataset.transforms.vision.utils as mode | |||
| from mindspore import log as logger | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| # updated VOC dataset with correct annotations | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| if bbox.size == 7: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_random_crop_with_bbox_op_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # define test OP with values to match existing Op UT | |||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_crop_with_bbox_op2_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||
| with md5 check, expected to pass | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op2_c") | |||
| original_seed = config_get_set_seed(593447) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # define test OP with values to match existing Op unit - test | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| filename = "random_crop_with_bbox_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_crop_with_bbox_op3_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||
| with Padding Mode explicitly passed | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op3_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # define test OP with values to match existing Op unit - test | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, | |||
| applied on dynamically generated edge case, expected to pass | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op_edge_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| # define test OP with values to match existing Op unit - test | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_crop_with_bbox_op_invalid_c(): | |||
| """ | |||
| Test RandomCropWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op_invalid_c") | |||
| # Load dataset | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| try: | |||
| # define test OP with values to match existing Op unit - test | |||
| test_op = c_vision.RandomCropWithBBox([512, 512, 375]) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| break | |||
| except TypeError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Size should be a single integer" in str(err) | |||
| def test_random_crop_with_bbox_op_bad_c(): | |||
| """ | |||
| Tests RandomCropWithBBox Op with invalid bounding boxes, expected to catch multiple errors. | |||
| """ | |||
| logger.info("test_random_crop_with_bbox_op_bad_c") | |||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| test_random_crop_with_bbox_op_c(plot_vis=True) | |||
| test_random_crop_with_bbox_op2_c(plot_vis=True) | |||
| test_random_crop_with_bbox_op3_c(plot_vis=True) | |||
| test_random_crop_with_bbox_op_edge_c(plot_vis=True) | |||
| test_random_crop_with_bbox_op_invalid_c() | |||
| test_random_crop_with_bbox_op_bad_c() | |||
| @@ -1,233 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the random horizontal flip with bounding boxes op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.log as logger | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| if bbox.size == 7: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): | |||
| """ | |||
| Prints images side by side with and without Aug applied + bboxes to | |||
| compare and test | |||
| """ | |||
| logger.info("test_random_horizontal_flip_with_bbox_op_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_horizontal_bbox_with_bbox_valid_rand_c(plot_vis=False): | |||
| """ | |||
| Uses a valid non-default input, expect to pass | |||
| Prints images side by side with and without Aug applied + bboxes to | |||
| compare and test | |||
| """ | |||
| logger.info("test_random_horizontal_bbox_valid_rand_c") | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(0.6) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "random_horizontal_flip_with_bbox_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False): | |||
| """ | |||
| Test RandomHorizontalFlipWithBBox op (testing with valid edge case, box covering full image). | |||
| Prints images side by side with and without Aug applied + bboxes to compare and test | |||
| """ | |||
| logger.info("test_horizontal_flip_with_bbox_valid_edge_c") | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1) | |||
| # maps to fix annotations to minddata standard | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| # Add column for "annotation" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_horizontal_flip_with_bbox_invalid_prob_c(): | |||
| """ | |||
| Test RandomHorizontalFlipWithBBox op with invalid input probability | |||
| """ | |||
| logger.info("test_random_horizontal_bbox_invalid_prob_c") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| try: | |||
| # Note: Valid range of prob should be [0.0, 1.0] | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1.5) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input is not" in str(error) | |||
| def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): | |||
| """ | |||
| Test RandomHorizontalFlipWithBBox op with invalid bounding boxes | |||
| """ | |||
| logger.info("test_random_horizontal_bbox_invalid_bounds_c") | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| # set to false to not show plots | |||
| test_random_horizontal_flip_with_bbox_op_c(plot_vis=False) | |||
| test_random_horizontal_bbox_with_bbox_valid_rand_c(plot_vis=False) | |||
| test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False) | |||
| test_random_horizontal_flip_with_bbox_invalid_prob_c() | |||
| test_random_horizontal_flip_with_bbox_invalid_bounds_c() | |||
| @@ -1,198 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the random resize with bounding boxes op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for (i, box) in enumerate(bboxes): | |||
| if box.size == 7: | |||
| bboxes[i] = np.roll(box, -1) | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_random_resize_with_bbox_op_rand_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomResizeWithBBox Op applied, | |||
| tests with MD5 check, expected to pass | |||
| """ | |||
| logger.info("test_random_resize_with_bbox_rand_c") | |||
| original_seed = config_get_set_seed(1) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomResizeWithBBox(200) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "random_resize_with_bbox_op_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_resize_with_bbox_op_edge_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomresizeWithBBox Op applied, | |||
| applied on dynamically generated edge case, expected to pass. edge case is when bounding | |||
| box has dimensions as the image itself. | |||
| """ | |||
| logger.info("test_random_resize_with_bbox_op_edge_c") | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomResizeWithBBox(500) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_resize_with_bbox_op_invalid_c(): | |||
| """ | |||
| Test RandomResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_resize_with_bbox_op_invalid_c") | |||
| try: | |||
| # zero value for resize | |||
| c_vision.RandomResizeWithBBox(0) | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| try: | |||
| # one of the size values is zero | |||
| c_vision.RandomResizeWithBBox((0, 100)) | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| try: | |||
| # negative value for resize | |||
| c_vision.RandomResizeWithBBox(-10) | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| try: | |||
| # invalid input shape | |||
| c_vision.RandomResizeWithBBox((100, 100, 100)) | |||
| except TypeError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Size should be" in str(err) | |||
| def test_random_resize_with_bbox_op_bad_c(): | |||
| """ | |||
| Tests RandomResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors | |||
| """ | |||
| logger.info("test_random_resize_with_bbox_op_bad_c") | |||
| test_op = c_vision.RandomResizeWithBBox((400, 300)) | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| test_random_resize_with_bbox_op_rand_c(plot_vis=False) | |||
| test_random_resize_with_bbox_op_edge_c(plot_vis=False) | |||
| test_random_resize_with_bbox_op_invalid_c() | |||
| test_random_resize_with_bbox_op_bad_c() | |||
| @@ -1,227 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing RandomVerticalFlipWithBBox op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| # updated VOC dataset with correct annotations | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for bbox in bboxes: | |||
| if bbox.size == 7: | |||
| tmp = bbox[0] | |||
| bbox[0] = bbox[1] | |||
| bbox[1] = bbox[2] | |||
| bbox[2] = bbox[3] | |||
| bbox[3] = bbox[4] | |||
| bbox[4] = tmp | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied | |||
| """ | |||
| logger.info("test_random_vertical_flip_with_bbox_op_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, | |||
| tests with MD5 check, expected to pass | |||
| """ | |||
| logger.info("test_random_vertical_flip_with_bbox_op_rand_c") | |||
| original_seed = config_get_set_seed(29847) | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(0.8) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "random_vertical_flip_with_bbox_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| # Restore config setting | |||
| ds.config.set_seed(original_seed) | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, | |||
| applied on dynamically generated edge case, expected to pass | |||
| """ | |||
| logger.info("test_random_vertical_flip_with_bbox_op_edge_c") | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_random_vertical_flip_with_bbox_op_invalid_c(): | |||
| """ | |||
| Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_random_vertical_flip_with_bbox_op_invalid_c") | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| try: | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(2) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| break | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "Input is not" in str(err) | |||
| def test_random_vertical_flip_with_bbox_op_bad_c(): | |||
| """ | |||
| Tests RandomVerticalFlipWithBBox Op with invalid bounding boxes, expected to catch multiple errors | |||
| """ | |||
| logger.info("test_random_vertical_flip_with_bbox_op_bad_c") | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| test_random_vertical_flip_with_bbox_op_c(plot_vis=True) | |||
| test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=True) | |||
| test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=True) | |||
| test_random_vertical_flip_with_bbox_op_invalid_c() | |||
| test_random_vertical_flip_with_bbox_op_bad_c() | |||
| @@ -1,169 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Testing the resize with bounding boxes op in DE | |||
| """ | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as c_vision | |||
| from mindspore import log as logger | |||
| from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ | |||
| save_and_check_md5 | |||
| GENERATE_GOLDEN = False | |||
| DATA_DIR = "../data/dataset/testVOC2012_2" | |||
| def fix_annotate(bboxes): | |||
| """ | |||
| Fix annotations to format followed by mindspore. | |||
| :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format | |||
| :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format | |||
| """ | |||
| for (i, box) in enumerate(bboxes): | |||
| if box.size == 7: | |||
| bboxes[i] = np.roll(box, -1) | |||
| else: | |||
| print("ERROR: Invalid Bounding Box size provided") | |||
| break | |||
| return bboxes | |||
| def test_resize_with_bbox_op_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without ResizeWithBBox Op applied, | |||
| tests with MD5 check, expected to pass | |||
| """ | |||
| logger.info("test_resize_with_bbox_op_c") | |||
| # Load dataset | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.ResizeWithBBox(200) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) | |||
| filename = "resize_with_bbox_op_01_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_resize_with_bbox_op_edge_c(plot_vis=False): | |||
| """ | |||
| Prints images and bboxes side by side with and without ResizeWithBBox Op applied, | |||
| applied on dynamically generated edge case, expected to pass. edge case is when bounding | |||
| box has dimensions as the image itself. | |||
| """ | |||
| logger.info("test_resize_with_bbox_op_edge_c") | |||
| dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", | |||
| decode=True, shuffle=False) | |||
| test_op = c_vision.ResizeWithBBox(500) | |||
| dataVoc1 = dataVoc1.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| dataVoc2 = dataVoc2.map(input_columns=["annotation"], | |||
| output_columns=["annotation"], | |||
| operations=fix_annotate) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): | |||
| unaugSamp.append(unAug) | |||
| augSamp.append(Aug) | |||
| if plot_vis: | |||
| visualize_with_bounding_boxes(unaugSamp, augSamp) | |||
| def test_resize_with_bbox_op_invalid_c(): | |||
| """ | |||
| Test ResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError | |||
| """ | |||
| logger.info("test_resize_with_bbox_op_invalid_c") | |||
| try: | |||
| # invalid interpolation value | |||
| c_vision.ResizeWithBBox(400, interpolation="invalid") | |||
| except ValueError as err: | |||
| logger.info("Got an exception in DE: {}".format(str(err))) | |||
| assert "interpolation" in str(err) | |||
| def test_resize_with_bbox_op_bad_c(): | |||
| """ | |||
| Tests ResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors | |||
| """ | |||
| logger.info("test_resize_with_bbox_op_bad_c") | |||
| test_op = c_vision.ResizeWithBBox((200, 300)) | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") | |||
| data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) | |||
| check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") | |||
| if __name__ == "__main__": | |||
| test_resize_with_bbox_op_c(plot_vis=False) | |||
| test_resize_with_bbox_op_edge_c(plot_vis=False) | |||
| test_resize_with_bbox_op_invalid_c() | |||
| test_resize_with_bbox_op_bad_c() | |||