You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 6.4 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/transforms.h"
  17. #include <algorithm>
  18. #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. // Transform operations for data.
  22. namespace transforms {
  23. // API CLASS FOR DATA TRANSFORM OPERATIONS
  24. // (In alphabetical order)
  25. // Constructor to Compose.
  26. struct Compose::Data {
  27. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  28. };
  29. Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
  30. (void)std::transform(
  31. transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  32. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  33. }
  34. Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
  35. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  36. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  37. return op != nullptr ? op->Parse() : nullptr;
  38. });
  39. }
  40. Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
  41. : data_(std::make_shared<Data>()) {
  42. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  43. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  44. }
  45. std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(data_->transforms_); }
  46. // Constructor to Duplicate
  47. Duplicate::Duplicate() {}
  48. std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
  49. // Constructor to OneHot
  50. struct OneHot::Data {
  51. explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
  52. float num_classes_;
  53. };
  54. OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}
  55. std::shared_ptr<TensorOperation> OneHot::Parse() { return std::make_shared<OneHotOperation>(data_->num_classes_); }
  56. // Constructor to RandomApply.
  57. struct RandomApply::Data {
  58. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  59. double prob_;
  60. };
  61. RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, double prob)
  62. : data_(std::make_shared<Data>()) {
  63. (void)std::transform(
  64. transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  65. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  66. data_->prob_ = prob;
  67. }
  68. RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
  69. : data_(std::make_shared<Data>()) {
  70. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  71. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  72. return op != nullptr ? op->Parse() : nullptr;
  73. });
  74. data_->prob_ = prob;
  75. }
  76. RandomApply::RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob)
  77. : data_(std::make_shared<Data>()) {
  78. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  79. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  80. data_->prob_ = prob;
  81. }
  82. std::shared_ptr<TensorOperation> RandomApply::Parse() {
  83. return std::make_shared<RandomApplyOperation>(data_->transforms_, data_->prob_);
  84. }
  85. // Constructor to RandomChoice.
  86. struct RandomChoice::Data {
  87. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  88. };
  89. RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
  90. (void)std::transform(
  91. transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  92. [](TensorTransform *op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
  93. }
  94. RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
  95. : data_(std::make_shared<Data>()) {
  96. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  97. [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
  98. return op != nullptr ? op->Parse() : nullptr;
  99. });
  100. }
  101. RandomChoice::RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
  102. : data_(std::make_shared<Data>()) {
  103. (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
  104. [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
  105. }
  106. std::shared_ptr<TensorOperation> RandomChoice::Parse() {
  107. return std::make_shared<RandomChoiceOperation>(data_->transforms_);
  108. }
  109. // Constructor to TypeCast
  110. struct TypeCast::Data {
  111. explicit Data(const std::vector<char> &data_type) : data_type_(CharToString(data_type)) {}
  112. std::string data_type_;
  113. };
  114. TypeCast::TypeCast(const std::vector<char> &data_type) : data_(std::make_shared<Data>(data_type)) {}
  115. std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
  116. // Constructor to Unique
  117. Unique::Unique() {}
  118. #ifndef ENABLE_ANDROID
  119. std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
  120. #else
  121. std::shared_ptr<TensorOperation> Unique::Parse() {
  122. MS_LOG(ERROR) << "Unique op is not supported for Android.";
  123. return nullptr;
  124. }
  125. #endif
  126. } // namespace transforms
  127. } // namespace dataset
  128. } // namespace mindspore