/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/include/dataset/transforms.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <vector>

#include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "minddata/dataset/core/type_id.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"

namespace mindspore {
namespace dataset {

// Transform operations for data.
namespace transforms {

// API CLASS FOR DATA TRANSFORM OPERATIONS
// (In alphabetical order)

// Constructor to Compose.
// Holds the IR nodes obtained by calling Parse() on each user-supplied transform.
struct Compose::Data {
  std::vector<std::shared_ptr<TensorOperation>> transforms_;
};

// Builds a Compose from raw transform pointers.
// A null pointer is not dereferenced; it is forwarded as a null TensorOperation.
// NOTE(review): presumably ComposeOperation rejects null entries during validation — confirm.
Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
  data_->transforms_.reserve(transforms.size());
  (void)std::transform(
    transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
    [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> { return op != nullptr ? op->Parse() : nullptr; });
}

// Builds a Compose from shared transform pointers; null entries are forwarded as null.
// Each element is borrowed by const reference so no refcount round-trip happens per element.
Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
  data_->transforms_.reserve(transforms.size());
  (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
                       [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
                         return op != nullptr ? op->Parse() : nullptr;
                       });
}

// Builds a Compose from transform references; references cannot be null, so each is parsed directly.
Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
    : data_(std::make_shared<Data>()) {
  data_->transforms_.reserve(transforms.size());
  (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
                       [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
}

// Wraps the already-parsed child IR nodes into a ComposeOperation.
std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(data_->transforms_); }

#ifndef ENABLE_ANDROID
// Constructor to Concatenate
// Stores the concatenation axis and the optional prepend/append tensors as given by the caller.
struct Concatenate::Data {
  explicit Data(int8_t axis, MSTensor prepend, MSTensor append)
      : axis_(axis), prepend_(prepend), append_(append) {}
  int8_t axis_;
  MSTensor prepend_;
  MSTensor append_;
};

Concatenate::Concatenate(int8_t axis, MSTensor prepend, MSTensor append)
    : data_(std::make_shared<Data>(axis, prepend, append)) {}

// Converts the prepend/append MSTensors into dataset Tensors and builds a ConcatenateOperation.
// Logs and returns nullptr if either conversion fails.
std::shared_ptr<TensorOperation> Concatenate::Parse() {
  std::shared_ptr<Tensor> out_prepend;
  std::shared_ptr<Tensor> out_append;
  Status rc = Tensor::CreateFromMSTensor(data_->prepend_, &out_prepend);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "Error creating prepend constant tensor. " << rc;
    return nullptr;
  }

  rc = Tensor::CreateFromMSTensor(data_->append_, &out_append);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "Error creating append constant tensor. " << rc;
    return nullptr;
  }
  return std::make_shared<ConcatenateOperation>(data_->axis_, out_prepend, out_append);
}
#endif  // not ENABLE_ANDROID

// Constructor to Duplicate
// Duplicate carries no state of its own.
Duplicate::Duplicate() = default;

std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }

#ifndef ENABLE_ANDROID
// Constructor to Fill
// Stores the user-supplied fill value.
struct Fill::Data {
  explicit Data(MSTensor fill_value) : fill_value_(fill_value) {}
  MSTensor fill_value_;
};

Fill::Fill(MSTensor fill_value) : data_(std::make_shared<Data>(fill_value)) {}

// Converts the fill value into a dataset Tensor and builds a FillOperation.
// Logs and returns nullptr if the conversion fails.
std::shared_ptr<TensorOperation> Fill::Parse() {
  std::shared_ptr<Tensor> out_fill_value;
  Status rc = Tensor::CreateFromMSTensor(data_->fill_value_, &out_fill_value);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "Error creating fill value tensor. " << rc;
    return nullptr;
  }
  return std::make_shared<FillOperation>(out_fill_value);
}

// Constructor to Mask
// Stores the relational operator, the comparison constant and the requested output type.
struct Mask::Data {
  explicit Data(RelationalOp op, MSTensor constant, mindspore::DataType ms_type)
      : op_(op), constant_(constant), ms_type_(ms_type) {}
  RelationalOp op_;
  MSTensor constant_;
  mindspore::DataType ms_type_;
};

Mask::Mask(RelationalOp op, MSTensor constant, mindspore::DataType ms_type)
    : data_(std::make_shared<Data>(op, constant, ms_type)) {}

// Converts the constant into a dataset Tensor and the mindspore::DataType into the dataset
// DataType, then builds a MaskOperation. Logs and returns nullptr if the tensor conversion fails.
std::shared_ptr<TensorOperation> Mask::Parse() {
  std::shared_ptr<Tensor> out_constant;
  Status rc = Tensor::CreateFromMSTensor(data_->constant_, &out_constant);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "Error creating constant tensor. " << rc;
    return nullptr;
  }

  DataType de_type = dataset::MSTypeToDEType(static_cast<TypeId>(data_->ms_type_));
  return std::make_shared<MaskOperation>(data_->op_, out_constant, de_type);
}
#endif  // not ENABLE_ANDROID

// Constructor to OneHot
// Stores the number of classes given by the caller.
struct OneHot::Data {
  explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
  int32_t num_classes_;
};

OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}

std::shared_ptr<TensorOperation> OneHot::Parse() { return std::make_shared<OneHotOperation>(data_->num_classes_); }

#ifndef ENABLE_ANDROID
// Constructor to PadEnd
// Stores the target shape and the value to pad with.
struct PadEnd::Data {
  explicit Data(const std::vector<dsize_t> &pad_shape, MSTensor pad_value)
      : pad_shape_(pad_shape), pad_value_(pad_value) {}
  std::vector<dsize_t> pad_shape_;
  MSTensor pad_value_;
};

PadEnd::PadEnd(const std::vector<dsize_t> &pad_shape, MSTensor pad_value)
    : data_(std::make_shared<Data>(pad_shape, pad_value)) {}

// Converts the pad value into a dataset Tensor and the pad shape into a TensorShape,
// then builds a PadEndOperation. Logs and returns nullptr if the tensor conversion fails.
std::shared_ptr<TensorOperation> PadEnd::Parse() {
  std::shared_ptr<Tensor> pad_value;
  Status rc = Tensor::CreateFromMSTensor(data_->pad_value_, &pad_value);
  if (rc.IsError()) {
    MS_LOG(ERROR) << "Error creating value constant tensor. " << rc;
    return nullptr;
  }
  return std::make_shared<PadEndOperation>(TensorShape(data_->pad_shape_), pad_value);
}
#endif  // not ENABLE_ANDROID

// Constructor to RandomApply.
struct RandomApply::Data { std::vector> transforms_; double prob_; }; RandomApply::RandomApply(const std::vector &transforms, double prob) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](TensorTransform *const op) -> std::shared_ptr { return op != nullptr ? op->Parse() : nullptr; }); data_->prob_ = prob; } RandomApply::RandomApply(const std::vector> &transforms, double prob) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](std::shared_ptr op) -> std::shared_ptr { return op != nullptr ? op->Parse() : nullptr; }); data_->prob_ = prob; } RandomApply::RandomApply(const std::vector> &transforms, double prob) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](TensorTransform &op) -> std::shared_ptr { return op.Parse(); }); data_->prob_ = prob; } std::shared_ptr RandomApply::Parse() { return std::make_shared(data_->transforms_, data_->prob_); } // Constructor to RandomChoice. struct RandomChoice::Data { std::vector> transforms_; }; RandomChoice::RandomChoice(const std::vector &transforms) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](TensorTransform *const op) -> std::shared_ptr { return op != nullptr ? op->Parse() : nullptr; }); } RandomChoice::RandomChoice(const std::vector> &transforms) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](std::shared_ptr op) -> std::shared_ptr { return op != nullptr ? 
op->Parse() : nullptr; }); } RandomChoice::RandomChoice(const std::vector> &transforms) : data_(std::make_shared()) { (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_), [](TensorTransform &op) -> std::shared_ptr { return op.Parse(); }); } std::shared_ptr RandomChoice::Parse() { return std::make_shared(data_->transforms_); } #ifndef ENABLE_ANDROID // Constructor to Slice struct Slice::Data { explicit Data(const std::vector &slice_input) : slice_input_(slice_input) {} std::vector slice_input_; }; Slice::Slice(const std::vector &slice_input) : data_(std::make_shared(slice_input)) {} std::shared_ptr Slice::Parse() { return std::make_shared(data_->slice_input_); } #endif // not ENABLE_ANDROID // Constructor to TypeCast struct TypeCast::Data { dataset::DataType data_type_; }; TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared()) { data_->data_type_ = dataset::MSTypeToDEType(static_cast(data_type)); } std::shared_ptr TypeCast::Parse() { return std::make_shared(data_->data_type_); } // Constructor to Unique Unique::Unique() {} #ifndef ENABLE_ANDROID std::shared_ptr Unique::Parse() { return std::make_shared(); } #else std::shared_ptr Unique::Parse() { MS_LOG(ERROR) << "Unique op is not supported for Android."; return nullptr; } #endif } // namespace transforms } // namespace dataset } // namespace mindspore