Browse Source

!1342 Bug fix on issue Core dump on GPU when train with lenet with AU

Merge pull request !1342 from Tinazhang/cc
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
39b9aedf68
2 changed files with 25 additions and 3 deletions
  1. +3
    -3
      mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
  2. +22
    -0
      tests/ut/python/dataset/test_uniform_augment.py

+ 3
- 3
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc View File

@@ -55,11 +55,11 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,

// apply C++ ops (note: python OPs are not accepted)
if (count == 1) {
(**tensor_op).Compute(input, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output));
} else if (count % 2 == 0) {
(**tensor_op).Compute(*output, even_out_ptr);
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr));
} else {
(**tensor_op).Compute(even_out, output);
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output));
}
count++;
}


+ 22
- 0
tests/ut/python/dataset/test_uniform_augment.py View File

@@ -226,6 +226,27 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
logger.info("Got an exception in DE: {}".format(str(e)))
assert "num_ops" in str(e)

def test_cpp_uniform_augment_random_crop_ut():
batch_size=2
cifar10_dir = "../data/dataset/testCifar10Data"
ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False) # shape = [32,32,3]

transforms_ua = [
C.RandomCrop(size=[224, 224]),
C.RandomHorizontalFlip()
]
uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=1)
ds1 = ds1.map(input_columns="image", operations=uni_aug)

# apply DatasetOps
ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
num_batches = 0
try:
for data in ds1.create_dict_iterator():
num_batches += 1
except BaseException as e:
assert "Crop size" in str(e)


if __name__ == "__main__":
test_uniform_augment(num_ops=1)
@@ -233,3 +254,4 @@ if __name__ == "__main__":
test_cpp_uniform_augment_exception_pyops(num_ops=1)
test_cpp_uniform_augment_exception_large_numops(num_ops=6)
test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
test_cpp_uniform_augment_random_crop_ut()

Loading…
Cancel
Save