You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 5.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "minddata/dataset/include/transforms.h"
  #include <algorithm>
  #include <iterator>
  #include <utility>
  #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. // Transform operations for data.
  22. namespace transforms {
  23. // API CLASS FOR DATA TRANSFORM OPERATIONS
  24. // (In alphabetical order)
  25. // Constructor to Compose.
  26. Compose::Compose(const std::vector<TensorTransform *> &transforms) {
  27. (void)std::transform(
  28. transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  29. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  30. }
  31. Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) {
  32. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  33. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  34. return op != nullptr ? op->Parse() : nullptr;
  35. });
  36. }
  37. Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms) {
  38. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  39. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  40. }
  41. std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(transforms_); }
  42. // Constructor to Duplicate
  43. Duplicate::Duplicate() {}
  44. std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
  45. // Constructor to OneHot
  46. OneHot::OneHot(int32_t num_classes) : num_classes_(num_classes) {}
  47. std::shared_ptr<TensorOperation> OneHot::Parse() { return std::make_shared<OneHotOperation>(num_classes_); }
  48. // Constructor to RandomApply.
  49. RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, double prob) : prob_(prob) {
  50. (void)std::transform(
  51. transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  52. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  53. }
  54. RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob) : prob_(prob) {
  55. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  56. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  57. return op != nullptr ? op->Parse() : nullptr;
  58. });
  59. }
  60. RandomApply::RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob)
  61. : prob_(prob) {
  62. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  63. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  64. }
  65. std::shared_ptr<TensorOperation> RandomApply::Parse() {
  66. return std::make_shared<RandomApplyOperation>(transforms_, prob_);
  67. }
  68. // Constructor to RandomChoice.
  69. RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) {
  70. (void)std::transform(
  71. transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  72. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  73. }
  74. RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms) {
  75. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  76. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  77. return op != nullptr ? op->Parse() : nullptr;
  78. });
  79. }
  80. RandomChoice::RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms) {
  81. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(transforms_),
  82. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  83. }
  84. std::shared_ptr<TensorOperation> RandomChoice::Parse() { return std::make_shared<RandomChoiceOperation>(transforms_); }
  85. // Constructor to TypeCast
  86. TypeCast::TypeCast(std::string data_type) : data_type_(data_type) {}
  87. std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_type_); }
  88. // Constructor to Unique
  89. Unique::Unique() {}
  90. #ifndef ENABLE_ANDROID
  91. std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
  92. #else
  93. std::shared_ptr<TensorOperation> Unique::Parse() {
  94. MS_LOG(ERROR) << "Unique op is not supported for Android.";
  95. return nullptr;
  96. }
  97. #endif
  98. } // namespace transforms
  99. } // namespace dataset
  100. } // namespace mindspore