| @@ -215,7 +215,7 @@ Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Te | |||
| auto itr = coordinate_map_.find(image_id); | |||
| if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); | |||
| std::string kImageFile = image_folder_path_ + image_id; | |||
| std::string kImageFile = image_folder_path_ + std::string("/") + image_id; | |||
| RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); | |||
| auto bboxRow = itr->second; | |||
| @@ -34,7 +34,10 @@ namespace mindspore { | |||
| namespace dataset { | |||
| const char kColumnImage[] = "image"; | |||
| const char kColumnTarget[] = "target"; | |||
| const char kColumnAnnotation[] = "annotation"; | |||
| const char kColumnBbox[] = "bbox"; | |||
| const char kColumnLabel[] = "label"; | |||
| const char kColumnDifficult[] = "difficult"; | |||
| const char kColumnTruncate[] = "truncate"; | |||
| const char kJPEGImagesFolder[] = "/JPEGImages/"; | |||
| const char kSegmentationClassFolder[] = "/SegmentationClass/"; | |||
| const char kAnnotationsFolder[] = "/Annotations/"; | |||
| @@ -70,7 +73,13 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||
| ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| } | |||
| *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, | |||
| @@ -190,14 +199,16 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten | |||
| RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); | |||
| (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); | |||
| } else if (task_type_ == TaskType::Detection) { | |||
| std::shared_ptr<Tensor> image, annotation; | |||
| std::shared_ptr<Tensor> image; | |||
| TensorRow annotation; | |||
| const std::string kImageFile = | |||
| folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); | |||
| const std::string kAnnotationFile = | |||
| folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); | |||
| RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); | |||
| RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); | |||
| (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); | |||
| RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation)); | |||
| trow->push_back(std::move(image)); | |||
| trow->insert(trow->end(), annotation.begin(), annotation.end()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -271,7 +282,7 @@ Status VOCOp::ParseAnnotationIds() { | |||
| const std::string kAnnotationName = | |||
| folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension); | |||
| RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName)); | |||
| if (label_map_.find(kAnnotationName) != label_map_.end()) { | |||
| if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) { | |||
| new_image_ids.push_back(id); | |||
| } | |||
| } | |||
| @@ -293,7 +304,7 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) { | |||
| if (!Path(path).Exists()) { | |||
| RETURN_STATUS_UNEXPECTED("File is not found : " + path); | |||
| } | |||
| Bbox bbox; | |||
| Annotation annotation; | |||
| XMLDocument doc; | |||
| XMLError e = doc.LoadFile(common::SafeCStr(path)); | |||
| if (e != XMLError::XML_SUCCESS) { | |||
| @@ -332,13 +343,13 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) { | |||
| } | |||
| if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 && | |||
| ymin > 0 && xmax > xmin && ymax > ymin) { | |||
| std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; | |||
| bbox.emplace_back(std::make_pair(label_name, bbox_list)); | |||
| std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, difficult, truncated}; | |||
| annotation.emplace_back(std::make_pair(label_name, bbox_list)); | |||
| label_index_[label_name] = 0; | |||
| } | |||
| object = object->NextSiblingElement("object"); | |||
| } | |||
| if (bbox.size() > 0) label_map_[path] = bbox; | |||
| if (annotation.size() > 0) annotation_map_[path] = annotation; | |||
| return Status::OK(); | |||
| } | |||
| @@ -374,31 +385,46 @@ Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &co | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, | |||
| std::shared_ptr<Tensor> *tensor) { | |||
| Bbox bbox_info = label_map_[path]; | |||
| std::vector<float> bbox_row; | |||
| dsize_t bbox_column_num = 0, bbox_num = 0; | |||
| for (auto box : bbox_info) { | |||
| if (label_index_.find(box.first) != label_index_.end()) { | |||
| std::vector<float> bbox; | |||
| bbox.insert(bbox.end(), box.second.begin(), box.second.end()); | |||
| if (class_index_.find(box.first) != class_index_.end()) { | |||
| bbox.push_back(static_cast<float>(class_index_[box.first])); | |||
| // When the task is Detection, the annotation is returned as four columns (one row per bounding box): | |||
| // column ["bbox"] with datatype=float32, shape [bbox_num, 4] | |||
| // column ["label"] with datatype=uint32, shape [bbox_num, 1] | |||
| // column ["difficult"] with datatype=uint32, shape [bbox_num, 1] | |||
| // column ["truncate"] with datatype=uint32, shape [bbox_num, 1] | |||
| Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) { | |||
| Annotation annotation = annotation_map_[path]; | |||
| std::shared_ptr<Tensor> bbox, label, difficult, truncate; | |||
| std::vector<float> bbox_data; | |||
| std::vector<uint32_t> label_data, difficult_data, truncate_data; | |||
| dsize_t bbox_num = 0; | |||
| for (auto item : annotation) { | |||
| if (label_index_.find(item.first) != label_index_.end()) { | |||
| if (class_index_.find(item.first) != class_index_.end()) { | |||
| label_data.push_back(static_cast<uint32_t>(class_index_[item.first])); | |||
| } else { | |||
| bbox.push_back(static_cast<float>(label_index_[box.first])); | |||
| } | |||
| bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); | |||
| if (bbox_column_num == 0) { | |||
| bbox_column_num = static_cast<dsize_t>(bbox.size()); | |||
| label_data.push_back(static_cast<uint32_t>(label_index_[item.first])); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(item.second.size() == 6, "annotation only support 6 parameters."); | |||
| std::vector<float> tmp_bbox = {(item.second)[0], (item.second)[1], (item.second)[2], (item.second)[3]}; | |||
| bbox_data.insert(bbox_data.end(), tmp_bbox.begin(), tmp_bbox.end()); | |||
| difficult_data.push_back(static_cast<uint32_t>((item.second)[4])); | |||
| truncate_data.push_back(static_cast<uint32_t>((item.second)[5])); | |||
| bbox_num++; | |||
| } | |||
| } | |||
| std::vector<dsize_t> bbox_dim = {bbox_num, bbox_column_num}; | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(), | |||
| reinterpret_cast<unsigned char *>(&bbox_row[0]))); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&bbox, data_schema_->column(1).tensorImpl(), TensorShape({bbox_num, 4}), | |||
| data_schema_->column(1).type(), | |||
| reinterpret_cast<unsigned char *>(&bbox_data[0]))); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(2).tensorImpl(), TensorShape({bbox_num, 1}), | |||
| data_schema_->column(2).type(), | |||
| reinterpret_cast<unsigned char *>(&label_data[0]))); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&difficult, data_schema_->column(3).tensorImpl(), TensorShape({bbox_num, 1}), | |||
| data_schema_->column(3).type(), | |||
| reinterpret_cast<unsigned char *>(&difficult_data[0]))); | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor(&truncate, data_schema_->column(4).tensorImpl(), TensorShape({bbox_num, 1}), | |||
| data_schema_->column(4).type(), | |||
| reinterpret_cast<unsigned char *>(&truncate_data[0]))); | |||
| (*row) = TensorRow({std::move(bbox), std::move(label), std::move(difficult), std::move(truncate)}); | |||
| return Status::OK(); | |||
| } | |||
| @@ -40,7 +40,7 @@ namespace dataset { | |||
| template <typename T> | |||
| class Queue; | |||
| using Bbox = std::vector<std::pair<std::string, std::vector<float>>>; | |||
| using Annotation = std::vector<std::pair<std::string, std::vector<float>>>; | |||
| class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| public: | |||
| @@ -234,10 +234,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | |||
| // @param const std::string &path - path to the annotation file | |||
| // @param const ColDescriptor &col - contains tensor implementation and datatype | |||
| // @param std::shared_ptr<Tensor> tensor - return | |||
| // @param TensorRow *row - output row receiving the bbox, label, difficult and truncate tensors | |||
| // @return Status - The error code return | |||
| Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor); | |||
| Status ReadAnnotationToTensor(const std::string &path, TensorRow *row); | |||
| // @param const std::vector<uint64_t> &keys - keys in ioblock | |||
| // @param std::unique_ptr<DataBuffer> db | |||
| @@ -287,7 +286,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| std::map<std::string, int32_t> class_index_; | |||
| std::map<std::string, int32_t> label_index_; | |||
| std::map<std::string, Bbox> label_map_; | |||
| std::map<std::string, Annotation> annotation_map_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -4122,13 +4122,11 @@ class VOCDataset(MappableDataset): | |||
| """ | |||
| A source dataset for reading and parsing VOC dataset. | |||
| The generated dataset has two columns : | |||
| task='Detection' : ['image', 'annotation']; | |||
| task='Segmentation' : ['image', 'target']. | |||
| The shape of both column 'image' and 'target' is [image_size] if decode flag is False, or [H, W, C] | |||
| otherwise. | |||
| The type of both tensor 'image' and 'target' is uint8. | |||
| The type of tensor 'annotation' is uint32. | |||
| The generated dataset has multiple columns: | |||
| - task='Detection', columns: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32], | |||
| ['difficult', dtype=uint32], ['truncate', dtype=uint32]]. | |||
| - task='Segmentation', columns: [['image', dtype=uint8], ['target', dtype=uint8]]. | |||
| This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table | |||
| below shows what input args are allowed and their expected behavior. | |||
| @@ -49,9 +49,9 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False): | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_rotation_c_result.npz" | |||
| @@ -88,9 +88,9 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False): | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_crop_c_result.npz" | |||
| @@ -126,10 +126,11 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False): | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| filename = "bounding_box_augment_valid_ratio_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| @@ -193,20 +194,20 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False): | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1) | |||
| # map to apply ops | |||
| # Add column for "annotation" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| # Add column for "bbox" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "bounding_box_augment_valid_edge_c_result.npz" | |||
| save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) | |||
| @@ -237,10 +238,10 @@ def test_bounding_box_augment_invalid_ratio_c(): | |||
| # ratio range is from 0 - 1 | |||
| test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error) | |||
| @@ -17,6 +17,7 @@ import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as vision | |||
| DATA_DIR = "../data/dataset/testCOCO/train/" | |||
| DATA_DIR_2 = "../data/dataset/testCOCO/train" | |||
| ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | |||
| KEYPOINT_FILE = "../data/dataset/testCOCO/annotations/key_point.json" | |||
| PANOPTIC_FILE = "../data/dataset/testCOCO/annotations/panoptic.json" | |||
| @@ -202,6 +203,17 @@ def test_coco_case_2(): | |||
| num_iter += 1 | |||
| assert num_iter == 24 | |||
| def test_coco_case_3(): | |||
| data1 = ds.CocoDataset(DATA_DIR_2, annotation_file=ANNOTATION_FILE, task="Detection", decode=True) | |||
| resize_op = vision.Resize((224, 224)) | |||
| data1 = data1.map(input_columns=["image"], operations=resize_op) | |||
| data1 = data1.repeat(4) | |||
| num_iter = 0 | |||
| for _ in data1.__iter__(): | |||
| num_iter += 1 | |||
| assert num_iter == 24 | |||
| def test_coco_case_exception(): | |||
| try: | |||
| data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection") | |||
| @@ -271,4 +283,5 @@ if __name__ == '__main__': | |||
| test_coco_case_0() | |||
| test_coco_case_1() | |||
| test_coco_case_2() | |||
| test_coco_case_3() | |||
| test_coco_case_exception() | |||
| @@ -36,8 +36,8 @@ def test_voc_detection(): | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| assert item["image"].shape[0] == IMAGE_SHAPE[num] | |||
| for bbox in item["annotation"]: | |||
| count[int(bbox[6])] += 1 | |||
| for label in item["label"]: | |||
| count[label[0]] += 1 | |||
| num += 1 | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| @@ -54,9 +54,9 @@ def test_voc_class_index(): | |||
| num = 0 | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| for bbox in item["annotation"]: | |||
| assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 5) | |||
| count[int(bbox[6])] += 1 | |||
| for label in item["label"]: | |||
| count[label[0]] += 1 | |||
| assert label[0] in (0, 1, 5) | |||
| num += 1 | |||
| assert num == 6 | |||
| assert count == [3, 2, 0, 0, 0, 3] | |||
| @@ -72,10 +72,9 @@ def test_voc_get_class_indexing(): | |||
| num = 0 | |||
| count = [0, 0, 0, 0, 0, 0] | |||
| for item in data1.create_dict_iterator(): | |||
| for bbox in item["annotation"]: | |||
| assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 2 or int(bbox[6]) == 3 | |||
| or int(bbox[6]) == 4 or int(bbox[6]) == 5) | |||
| count[int(bbox[6])] += 1 | |||
| for label in item["label"]: | |||
| count[label[0]] += 1 | |||
| assert label[0] in (0, 1, 2, 3, 4, 5) | |||
| num += 1 | |||
| assert num == 9 | |||
| assert count == [3, 2, 1, 2, 4, 3] | |||
| @@ -48,9 +48,9 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False): | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "random_resized_crop_with_bbox_01_c_result.npz" | |||
| @@ -114,15 +114,15 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -149,9 +149,9 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| @@ -175,9 +175,9 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): | |||
| test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| @@ -206,9 +206,9 @@ def test_random_resized_crop_with_bbox_op_bad_c(): | |||
| if __name__ == "__main__": | |||
| test_random_resized_crop_with_bbox_op_c(plot_vis=True) | |||
| test_random_resized_crop_with_bbox_op_coco_c(plot_vis=True) | |||
| test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) | |||
| test_random_resized_crop_with_bbox_op_c(plot_vis=False) | |||
| test_random_resized_crop_with_bbox_op_coco_c(plot_vis=False) | |||
| test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False) | |||
| test_random_resized_crop_with_bbox_op_invalid_c() | |||
| test_random_resized_crop_with_bbox_op_invalid2_c() | |||
| test_random_resized_crop_with_bbox_op_bad_c() | |||
| @@ -46,10 +46,10 @@ def test_random_crop_with_bbox_op_c(plot_vis=False): | |||
| test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -108,9 +108,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "random_crop_with_bbox_01_c_result.npz" | |||
| @@ -145,9 +145,9 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -175,16 +175,16 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False): | |||
| test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| @@ -212,10 +212,10 @@ def test_random_crop_with_bbox_op_invalid_c(): | |||
| test_op = c_vision.RandomCropWithBBox([512, 512, 375]) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| break | |||
| @@ -45,9 +45,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -111,9 +111,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False): | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(0.6) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "random_horizontal_flip_with_bbox_01_c_result.npz" | |||
| @@ -146,20 +146,20 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False): | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1) | |||
| # map to apply ops | |||
| # Add column for "annotation" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| # Add column for "bbox" | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=lambda img, bbox: | |||
| (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -184,10 +184,10 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): | |||
| # Note: Valid range of prob should be [0.0, 1.0] | |||
| test_op = c_vision.RandomHorizontalFlipWithBBox(1.5) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| except ValueError as error: | |||
| logger.info("Got an exception in DE: {}".format(str(error))) | |||
| assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) | |||
| @@ -48,9 +48,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False): | |||
| test_op = c_vision.RandomResizeWithBBox(100) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "random_resize_with_bbox_op_01_c_voc_result.npz" | |||
| @@ -129,15 +129,15 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False): | |||
| test_op = c_vision.RandomResizeWithBBox(500) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| @@ -46,9 +46,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -111,9 +111,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(0.8) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "random_vertical_flip_with_bbox_01_c_result.npz" | |||
| @@ -148,15 +148,15 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(1) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| # Test Op added to list of Operations here | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| unaugSamp, augSamp = [], [] | |||
| @@ -181,9 +181,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): | |||
| test_op = c_vision.RandomVerticalFlipWithBBox(2) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| for _ in dataVoc2.create_dict_iterator(): | |||
| @@ -48,9 +48,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False): | |||
| test_op = c_vision.ResizeWithBBox(100) | |||
| # map to apply ops | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) | |||
| filename = "resize_with_bbox_op_01_c_voc_result.npz" | |||
| @@ -119,15 +119,15 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False): | |||
| test_op = c_vision.ResizeWithBBox(500) | |||
| # maps to convert data into valid edge case data | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[lambda img, bboxes: ( | |||
| img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) | |||
| @@ -252,13 +252,13 @@ def visualize_image(image_original, image_de, mse=None, image_lib=None): | |||
| plt.show() | |||
| def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3): | |||
| def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3): | |||
| """ | |||
| Take a list of un-augmented and augmented images with "annotation" bounding boxes | |||
| Take a list of un-augmented and augmented images with "bbox" bounding boxes | |||
| Plot the images to compare them and check that BBox augmentation works correctly | |||
| :param orig: list of original images and bboxes (without aug) | |||
| :param aug: list of augmented images and bboxes | |||
| :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "annotation" (VOC) | |||
| :param annot_name: the dict key for bboxes in data, e.g. "bbox" (used by both COCO and VOC) | |||
| :param plot_rows: number of rows on plot (rows = samples on one plot) | |||
| :return: None | |||
| """ | |||
| @@ -337,7 +337,7 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error): | |||
| :return: None | |||
| """ | |||
| def add_bad_annotation(img, bboxes, invalid_bbox_type_): | |||
| def add_bad_bbox(img, bboxes, invalid_bbox_type_): | |||
| """ | |||
| Used to generate erroneous bounding box examples on given img. | |||
| :param img: image where the bounding boxes are. | |||
| @@ -366,15 +366,15 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error): | |||
| try: | |||
| # map to use selected invalid bounding box type | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=lambda img, bboxes: add_bad_annotation(img, bboxes, invalid_bbox_type)) | |||
| data = data.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type)) | |||
| # map to apply ops | |||
| data = data.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation"], | |||
| columns_order=["image", "annotation"], | |||
| operations=[test_op]) # Add column for "annotation" | |||
| data = data.map(input_columns=["image", "bbox"], | |||
| output_columns=["image", "bbox"], | |||
| columns_order=["image", "bbox"], | |||
| operations=[test_op]) # Add column for "bbox" | |||
| for _, _ in enumerate(data.create_dict_iterator()): | |||
| break | |||
| except RuntimeError as error: | |||