Browse Source

fix get_dataset_size in CelebADataset when usage is not all

tags/v1.0.0
yanghaitao1 5 years ago
parent
commit
4ff4c17632
7 changed files with 67 additions and 16 deletions
  1. +22
    -1
      mindspore/dataset/engine/datasets.py
  2. +1
    -1
      tests/ut/cpp/dataset/c_api_datasets_test.cc
  3. +12
    -6
      tests/ut/cpp/dataset/celeba_op_test.cc
  4. +3
    -1
      tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt
  5. +4
    -0
      tests/ut/data/dataset/testCelebAData/list_eval_partition.txt
  6. +24
    -6
      tests/ut/python/dataset/test_datasets_celeba.py
  7. +1
    -1
      tests/ut/python/dataset/test_paddeddataset.py

+ 22
- 1
mindspore/dataset/engine/datasets.py View File

@@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset):
with open(attr_file, 'r') as f:
num_rows = int(f.readline())
except FileNotFoundError:
- raise RuntimeError("attr_file not found.")
+ raise RuntimeError("attr file can not be found.")
except BaseException:
raise RuntimeError("Get dataset size failed from attribution file.")
if self.usage != 'all':
partition_file = os.path.join(dir, "list_eval_partition.txt")
usage_type = 0
partition_num = 0
if self.usage == "train":
usage_type = 0
elif self.usage == "valid":
usage_type = 1
elif self.usage == "test":
usage_type = 2
try:
with open(partition_file, 'r') as f:
for line in f.readlines():
split_line = line.split(' ')
if int(split_line[1]) == usage_type:
partition_num += 1
except FileNotFoundError:
raise RuntimeError("Partition file can not be found")
if partition_num < num_rows:
num_rows = partition_num

self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = self.num_samples


+ 1
- 1
tests/ut/cpp/dataset/c_api_datasets_test.cc View File

@@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
i++;
}

- EXPECT_EQ(i, 2);
+ EXPECT_EQ(i, 4);

// Manually terminate the pipeline
iter->Stop();


+ 12
- 6
tests/ut/cpp/dataset/celeba_op_test.cc View File

@@ -58,8 +58,10 @@ protected:

TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
std::string dir = datasets_root_path_ + "/testCelebAData/";
- uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
- {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
+ uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+ {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+ {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+ {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir)});
tree->Prepare();
@@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
count++;
di.GetNextAsMap(&tersor_map);
}
- EXPECT_TRUE(count == 2);
+ EXPECT_TRUE(count == 4);
}
}

TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
std::string dir = datasets_root_path_ + "/testCelebAData/";
- uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+ uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
tree->Prepare();
@@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
count++;
di.GetNextAsMap(&tersor_map);
}
- EXPECT_TRUE(count == 4);
+ EXPECT_TRUE(count == 8);
}
}



+ 3
- 1
tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt View File

@@ -1,4 +1,6 @@
- 2
+ 4
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

+ 4
- 0
tests/ut/data/dataset/testCelebAData/list_eval_partition.txt View File

@@ -0,0 +1,4 @@
1.JPEG 0
2.jpeg 1
2.jpeg 2
2.jpeg 0

+ 24
- 6
tests/ut/python/dataset/test_datasets_celeba.py View File

@@ -25,6 +25,10 @@ def test_celeba_dataset_label():
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
0, 0, 1],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1],
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
0, 0, 1]]
count = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -35,7 +39,7 @@ def test_celeba_dataset_label():
for index in range(len(expect_labels[count])):
assert item["attr"][index] == expect_labels[count][index]
count = count + 1
- assert count == 2
+ assert count == 4


def test_celeba_dataset_op():
@@ -54,14 +58,17 @@ def test_celeba_dataset_op():
logger.info("----------image--------")
logger.info(item["image"])
count = count + 1
- assert count == 4
+ assert count == 8


def test_celeba_dataset_ext():
ext = [".JPEG"]
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
- expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
- 0, 1, 0, 1, 0, 0, 1],
+ expect_labels = [
+ [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
+ 0, 1, 0, 1, 0, 0, 1],
+ [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
+ 0, 1, 0, 1, 0, 0, 1]]
count = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("----------image--------")
@@ -71,7 +78,7 @@ def test_celeba_dataset_ext():
for index in range(len(expect_labels[count])):
assert item["attr"][index] == expect_labels[count][index]
count = count + 1
- assert count == 1
+ assert count == 2


def test_celeba_dataset_distribute():
@@ -83,14 +90,25 @@ def test_celeba_dataset_distribute():
logger.info("----------attr--------")
logger.info(item["attr"])
count = count + 1
- assert count == 1
+ assert count == 2


def test_celeba_get_dataset_size():
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
size = data.get_dataset_size()
assert size == 4

data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
size = data.get_dataset_size()
assert size == 2

data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
size = data.get_dataset_size()
assert size == 1

data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
size = data.get_dataset_size()
assert size == 1

if __name__ == '__main__':
test_celeba_dataset_label()


+ 1
- 1
tests/ut/python/dataset/test_paddeddataset.py View File

@@ -504,7 +504,7 @@ def test_celeba_padded():
count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count = count + 1
- assert count == 2
+ assert count == 4


if __name__ == '__main__':


Loading…
Cancel
Save