| @@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset): | |||
| with open(attr_file, 'r') as f: | |||
| num_rows = int(f.readline()) | |||
| except FileNotFoundError: | |||
| raise RuntimeError("attr_file not found.") | |||
| raise RuntimeError("attr file can not be found.") | |||
| except BaseException: | |||
| raise RuntimeError("Get dataset size failed from attribution file.") | |||
| if self.usage != 'all': | |||
| partition_file = os.path.join(dir, "list_eval_partition.txt") | |||
| usage_type = 0 | |||
| partition_num = 0 | |||
| if self.usage == "train": | |||
| usage_type = 0 | |||
| elif self.usage == "valid": | |||
| usage_type = 1 | |||
| elif self.usage == "test": | |||
| usage_type = 2 | |||
| try: | |||
| with open(partition_file, 'r') as f: | |||
| for line in f.readlines(): | |||
| split_line = line.split(' ') | |||
| if int(split_line[1]) == usage_type: | |||
| partition_num += 1 | |||
| except FileNotFoundError: | |||
| raise RuntimeError("Partition file can not be found") | |||
| if partition_num < num_rows: | |||
| num_rows = partition_num | |||
| self.dataset_size = get_num_rows(num_rows, self.num_shards) | |||
| if self.num_samples is not None and self.num_samples < self.dataset_size: | |||
| self.dataset_size = self.num_samples | |||
| @@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) { | |||
| i++; | |||
| } | |||
| EXPECT_EQ(i, 2); | |||
| EXPECT_EQ(i, 4); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| @@ -58,8 +58,10 @@ protected: | |||
| TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) { | |||
| std::string dir = datasets_root_path_ + "/testCelebAData/"; | |||
| uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; | |||
| uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; | |||
| uint32_t count = 0; | |||
| auto tree = Build({Celeba(16, 2, 32, dir)}); | |||
| tree->Prepare(); | |||
| @@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) { | |||
| count++; | |||
| di.GetNextAsMap(&tersor_map); | |||
| } | |||
| EXPECT_TRUE(count == 2); | |||
| EXPECT_TRUE(count == 4); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { | |||
| std::string dir = datasets_root_path_ + "/testCelebAData/"; | |||
| uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, | |||
| {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; | |||
| uint32_t count = 0; | |||
| auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)}); | |||
| tree->Prepare(); | |||
| @@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { | |||
| count++; | |||
| di.GetNextAsMap(&tersor_map); | |||
| } | |||
| EXPECT_TRUE(count == 4); | |||
| EXPECT_TRUE(count == 8); | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| 2 | |||
| 4 | |||
| 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young | |||
| 1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1 | |||
| 2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 | |||
| 2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 | |||
| 1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1 | |||
| @@ -0,0 +1,4 @@ | |||
| 1.JPEG 0 | |||
| 2.jpeg 1 | |||
| 2.jpeg 2 | |||
| 2.jpeg 0 | |||
| @@ -25,6 +25,10 @@ def test_celeba_dataset_label(): | |||
| [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, | |||
| 0, 0, 1], | |||
| [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, | |||
| 0, 0, 1], | |||
| [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, | |||
| 0, 0, 1], | |||
| [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, | |||
| 0, 0, 1]] | |||
| count = 0 | |||
| for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| @@ -35,7 +39,7 @@ def test_celeba_dataset_label(): | |||
| for index in range(len(expect_labels[count])): | |||
| assert item["attr"][index] == expect_labels[count][index] | |||
| count = count + 1 | |||
| assert count == 2 | |||
| assert count == 4 | |||
| def test_celeba_dataset_op(): | |||
| @@ -54,14 +58,17 @@ def test_celeba_dataset_op(): | |||
| logger.info("----------image--------") | |||
| logger.info(item["image"]) | |||
| count = count + 1 | |||
| assert count == 4 | |||
| assert count == 8 | |||
| def test_celeba_dataset_ext(): | |||
| ext = [".JPEG"] | |||
| data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext) | |||
| expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, | |||
| 0, 1, 0, 1, 0, 0, 1], | |||
| expect_labels = [ | |||
| [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, | |||
| 0, 1, 0, 1, 0, 0, 1], | |||
| [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, | |||
| 0, 1, 0, 1, 0, 0, 1]] | |||
| count = 0 | |||
| for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| logger.info("----------image--------") | |||
| @@ -71,7 +78,7 @@ def test_celeba_dataset_ext(): | |||
| for index in range(len(expect_labels[count])): | |||
| assert item["attr"][index] == expect_labels[count][index] | |||
| count = count + 1 | |||
| assert count == 1 | |||
| assert count == 2 | |||
| def test_celeba_dataset_distribute(): | |||
| @@ -83,14 +90,25 @@ def test_celeba_dataset_distribute(): | |||
| logger.info("----------attr--------") | |||
| logger.info(item["attr"]) | |||
| count = count + 1 | |||
| assert count == 1 | |||
| assert count == 2 | |||
def test_celeba_get_dataset_size():
    """
    Check CelebADataset.get_dataset_size() for the full dataset and for each usage split.

    The first case passes no `usage` argument; the remaining cases select the
    "train", "valid" and "test" partitions in that order.
    """
    # Each entry pairs the extra constructor keywords with the expected row count.
    cases = (
        ({}, 4),
        ({"usage": "train"}, 2),
        ({"usage": "valid"}, 1),
        ({"usage": "test"}, 1),
    )
    for extra_kwargs, expected_size in cases:
        dataset = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, **extra_kwargs)
        assert dataset.get_dataset_size() == expected_size
| if __name__ == '__main__': | |||
| test_celeba_dataset_label() | |||
| @@ -504,7 +504,7 @@ def test_celeba_padded(): | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count = count + 1 | |||
| assert count == 2 | |||
| assert count == 4 | |||
| if __name__ == '__main__': | |||