You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 6.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/transforms.h"
  17. #include <algorithm>
  18. #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. // Transform operations for data.
  22. namespace transforms {
  23. // API CLASS FOR DATA TRANSFORM OPERATIONS
  24. // (In alphabetical order)
  25. // Constructor to Compose.
  26. struct Compose::Data {
  27. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  28. };
  29. Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
  30. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  31. [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
  32. return op != nullptr ? op->Parse() : nullptr;
  33. });
  34. }
  35. Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
  36. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  37. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  38. return op != nullptr ? op->Parse() : nullptr;
  39. });
  40. }
  41. Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
  42. : data_(std::make_shared<Data>()) {
  43. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  44. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  45. }
  46. std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(data_->transforms_); }
  47. // Constructor to Duplicate
  48. Duplicate::Duplicate() {}
  49. std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
  50. // Constructor to OneHot
  51. struct OneHot::Data {
  52. explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
  53. float num_classes_;
  54. };
  55. OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}
  56. std::shared_ptr<TensorOperation> OneHot::Parse() { return std::make_shared<OneHotOperation>(data_->num_classes_); }
  57. // Constructor to RandomApply.
  58. struct RandomApply::Data {
  59. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  60. double prob_;
  61. };
  62. RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, double prob)
  63. : data_(std::make_shared<Data>()) {
  64. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  65. [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
  66. return op != nullptr ? op->Parse() : nullptr;
  67. });
  68. data_->prob_ = prob;
  69. }
  70. RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
  71. : data_(std::make_shared<Data>()) {
  72. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  73. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  74. return op != nullptr ? op->Parse() : nullptr;
  75. });
  76. data_->prob_ = prob;
  77. }
  78. RandomApply::RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob)
  79. : data_(std::make_shared<Data>()) {
  80. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  81. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  82. data_->prob_ = prob;
  83. }
  84. std::shared_ptr<TensorOperation> RandomApply::Parse() {
  85. return std::make_shared<RandomApplyOperation>(data_->transforms_, data_->prob_);
  86. }
  87. // Constructor to RandomChoice.
  88. struct RandomChoice::Data {
  89. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  90. };
  91. RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
  92. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  93. [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
  94. return op != nullptr ? op->Parse() : nullptr;
  95. });
  96. }
  97. RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
  98. : data_(std::make_shared<Data>()) {
  99. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  100. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  101. return op != nullptr ? op->Parse() : nullptr;
  102. });
  103. }
  104. RandomChoice::RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
  105. : data_(std::make_shared<Data>()) {
  106. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  107. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  108. }
  109. std::shared_ptr<TensorOperation> RandomChoice::Parse() {
  110. return std::make_shared<RandomChoiceOperation>(data_->transforms_);
  111. }
  112. // Constructor to TypeCast
  113. struct TypeCast::Data {
  114. explicit Data(const std::vector<char> &data_type) : data_type_(CharToString(data_type)) {}
  115. std::string data_type_;
  116. };
  117. TypeCast::TypeCast(const std::vector<char> &data_type) : data_(std::make_shared<Data>(data_type)) {}
  118. std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
  119. // Constructor to Unique
  120. Unique::Unique() {}
  121. #ifndef ENABLE_ANDROID
  122. std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
  123. #else
  124. std::shared_ptr<TensorOperation> Unique::Parse() {
  125. MS_LOG(ERROR) << "Unique op is not supported for Android.";
  126. return nullptr;
  127. }
  128. #endif
  129. } // namespace transforms
  130. } // namespace dataset
  131. } // namespace mindspore