| @@ -25,18 +25,14 @@ UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops | |||||
| std::shared_ptr<TensorOp> tensor_op; | std::shared_ptr<TensorOp> tensor_op; | ||||
| // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ | // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ | ||||
| for (auto op : op_list) { | for (auto op : op_list) { | ||||
| if (py::isinstance<py::function>(op)) { | |||||
| // python op | |||||
| tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>()); | |||||
| } else if (py::isinstance<TensorOp>(op)) { | |||||
| // C++ op | |||||
| tensor_op = op.cast<std::shared_ptr<TensorOp>>(); | |||||
| } | |||||
| // only C++ op is accepted | |||||
| tensor_op = op.cast<std::shared_ptr<TensorOp>>(); | |||||
| tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); | tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); | ||||
| } | } | ||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| } | } | ||||
| // compute method to apply uniformly random selected augmentations from a list | // compute method to apply uniformly random selected augmentations from a list | ||||
| Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | ||||
| std::vector<std::shared_ptr<Tensor>> *output) { | std::vector<std::shared_ptr<Tensor>> *output) { | ||||
| @@ -57,7 +53,7 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, | |||||
| continue; | continue; | ||||
| } | } | ||||
| // apply python/C++ op | |||||
| // apply C++ ops (note: python OPs are not accepted) | |||||
| if (count == 1) { | if (count == 1) { | ||||
| (**tensor_op).Compute(input, output); | (**tensor_op).Compute(input, output); | ||||
| } else if (count % 2 == 0) { | } else if (count % 2 == 0) { | ||||
| @@ -36,7 +36,7 @@ class UniformAugOp : public TensorOp { | |||||
| static const int kDefNumOps; | static const int kDefNumOps; | ||||
| // Constructor for UniformAugOp | // Constructor for UniformAugOp | ||||
| // @param list op_list: list of candidate python operations | |||||
| // @param list op_list: list of candidate C++ operations | |||||
| // @param list num_ops: number of augemtation operations to applied | // @param list num_ops: number of augemtation operations to applied | ||||
| UniformAugOp(py::list op_list, int32_t num_ops); | UniformAugOp(py::list op_list, int32_t num_ops); | ||||
| @@ -455,8 +455,19 @@ class UniformAugment(cde.UniformAugOp): | |||||
| Tensor operation to perform randomly selected augmentation | Tensor operation to perform randomly selected augmentation | ||||
| Args: | Args: | ||||
| operations: list of python operations. | |||||
| operations: list of C++ operations (python OPs are not accepted). | |||||
| NumOps (int): number of OPs to be selected and applied. | NumOps (int): number of OPs to be selected and applied. | ||||
| Examples: | |||||
| >>> transforms_list = [c_transforms.RandomHorizontalFlip(), | |||||
| >>> c_transforms.RandomVerticalFlip(), | |||||
| >>> c_transforms.RandomColorAdjust(), | |||||
| >>> c_transforms.RandomRotation(degrees=45)] | |||||
| >>> uni_aug = c_transforms.UniformAugment(operations=transforms_list, num_ops=2) | |||||
| >>> transforms_all = [c_transforms.Decode(), c_transforms.Resize(size=[224, 224]), | |||||
| >>> uni_aug, F.ToTensor()] | |||||
| >>> ds_ua = ds.map(input_columns="image", | |||||
| >>> operations=transforms_all, num_parallel_workers=1) | |||||
| """ | """ | ||||
| @check_uniform_augmentation | @check_uniform_augmentation | ||||
| @@ -837,8 +837,8 @@ def check_uniform_augmentation(method): | |||||
| if not isinstance(operations, list): | if not isinstance(operations, list): | ||||
| raise ValueError("operations is not a python list") | raise ValueError("operations is not a python list") | ||||
| for op in operations: | for op in operations: | ||||
| if not callable(op) and not isinstance(op, TensorOp): | |||||
| raise ValueError("non-callable op in operations list") | |||||
| if not isinstance(op, TensorOp): | |||||
| raise ValueError("operations list only accepts C++ operations.") | |||||
| kwargs["num_ops"] = num_ops | kwargs["num_ops"] = num_ops | ||||
| kwargs["operations"] = operations | kwargs["operations"] = operations | ||||
| @@ -163,7 +163,68 @@ def test_cpp_uniform_augment(plot=False, num_ops=2): | |||||
| mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2) | mse[i] = np.mean((images_ua[i] - images_original[i]) ** 2) | ||||
| logger.info("MSE= {}".format(str(np.mean(mse)))) | logger.info("MSE= {}".format(str(np.mean(mse)))) | ||||
| def test_cpp_uniform_augment_exception_pyops(num_ops=2): | |||||
| """ | |||||
| Test UniformAugment invalid op in operations | |||||
| """ | |||||
| logger.info("Test CPP UniformAugment invalid OP exception") | |||||
| transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), | |||||
| C.RandomHorizontalFlip(), | |||||
| C.RandomVerticalFlip(), | |||||
| C.RandomColorAdjust(), | |||||
| C.RandomRotation(degrees=45), | |||||
| F.Invert()] | |||||
| try: | |||||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||||
| except BaseException as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "operations" in str(e) | |||||
| def test_cpp_uniform_augment_exception_large_numops(num_ops=6): | |||||
| """ | |||||
| Test UniformAugment invalid large number of ops | |||||
| """ | |||||
| logger.info("Test CPP UniformAugment invalid large num_ops exception") | |||||
| transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), | |||||
| C.RandomHorizontalFlip(), | |||||
| C.RandomVerticalFlip(), | |||||
| C.RandomColorAdjust(), | |||||
| C.RandomRotation(degrees=45)] | |||||
| try: | |||||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||||
| except BaseException as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "num_ops" in str(e) | |||||
| def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): | |||||
| """ | |||||
| Test UniformAugment invalid non-positive number of ops | |||||
| """ | |||||
| logger.info("Test CPP UniformAugment invalid non-positive num_ops exception") | |||||
| transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]), | |||||
| C.RandomHorizontalFlip(), | |||||
| C.RandomVerticalFlip(), | |||||
| C.RandomColorAdjust(), | |||||
| C.RandomRotation(degrees=45)] | |||||
| try: | |||||
| uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) | |||||
| except BaseException as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "num_ops" in str(e) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_uniform_augment(num_ops=1) | test_uniform_augment(num_ops=1) | ||||
| test_cpp_uniform_augment(num_ops=1) | test_cpp_uniform_augment(num_ops=1) | ||||
| test_cpp_uniform_augment_exception_pyops(num_ops=1) | |||||
| test_cpp_uniform_augment_exception_large_numops(num_ops=6) | |||||
| test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0) | |||||