/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/kernels/data/data_utils.h"

#include <vector>

#include "dataset/core/constants.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_shape.h"
#include "dataset/core/data_type.h"
#include "dataset/core/pybind_support.h"

namespace mindspore {
namespace dataset {
// Sets row `index` of the pre-zeroed one-hot output tensor for an unsigned-int input tensor.
// @param input - scalar or 1-D tensor of class indices (unsigned int type)
// @param output - pre-allocated, zero-filled 2-D tensor of shape {num_elements, num_classes}
// @param num_classes - width of the one-hot dimension; class index must be < num_classes
// @param index - row of `output` (and, for rank-1 input, element of `input`) to process
// @return Status - error if the class index is out of range or the type is not unsigned int
Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                              dsize_t num_classes, int64_t index) {
  uint64_t class_idx;
  if (input->Rank() == 0) {
    // Scalar input: the single value is the class index.
    RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
  } else {
    RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
  }
  if (class_idx >= static_cast<uint64_t>(num_classes)) {
    RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  }
  // Write the "hot" 1 with the element type matching the input tensor's dtype.
  if (input->type() == DataType::DE_UINT64) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_UINT32) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_UINT16) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_UINT8) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else {
    RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
  }
  return Status::OK();
}

// Sets row `index` of the pre-zeroed one-hot output tensor for a signed-int input tensor.
// Mirrors OneHotEncodingUnsigned; see that function for parameter semantics.
Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                            dsize_t num_classes, int64_t index) {
  int64_t class_idx;
  if (input->Rank() == 0) {
    RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
  } else {
    RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
  }
  if (class_idx >= static_cast<int64_t>(num_classes)) {
    RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  }
  // Write the "hot" 1 with the element type matching the input tensor's dtype.
  if (input->type() == DataType::DE_INT64) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_INT32) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_INT16) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else if (input->type() == DataType::DE_INT8) {
    RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  } else {
    RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
  }
  return Status::OK();
}

// One-hot encodes a scalar or 1-D integer tensor into a tensor of shape
// {num_elements, num_classes} (squeezed afterwards, so a scalar input yields a 1-D output).
// @param input - scalar or 1-D tensor of class indices (any int type)
// @param output - receives the newly created one-hot tensor (same dtype as input)
// @param num_classes - width of the one-hot dimension
// @return Status - the error code returned
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  input->Squeeze();
  if (input->Rank() > 1) {  // We expect the input to be in the first dimension
    RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  }
  if (!input->type().IsInt()) {
    RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
  }
  try {
    dsize_t num_elements = 1;
    if (input->Rank() == 1) num_elements = input->shape()[0];
    TensorShape out_shape({num_elements, num_classes});
    std::shared_ptr<Tensor> out;
    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
    RETURN_IF_NOT_OK(out->Zero());
    for (dsize_t i = 0; i < num_elements; ++i) {
      // Dispatch on signedness; each call fills one row of `out`.
      if (input->type().IsUnsignedInt()) {
        RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
      } else {
        RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
      }
    }
    out->Squeeze();
    *output = out;
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
  }
}

// Element-wise copy from `input` (elements of type FROM) into `output` (elements of type TO),
// converting each value with static_cast. `output` must already be allocated with the same
// number of elements as `input`.
template <typename FROM, typename TO>
void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  auto in_itr = input->begin<FROM>();
  auto out_itr = (*output)->begin<TO>();
  auto out_end = (*output)->end<TO>();
  // static_cast<void> silences "unused result" warnings on the post-increments.
  for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
    *out_itr = static_cast<TO>(*in_itr);
}

// Dispatches Cast<T, TO> on the destination dtype of `output`; T is the source element type.
template <typename T>
void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  switch ((*output)->type().value()) {
    case DataType::DE_BOOL:
      Cast<T, bool>(input, output);
      break;
    case DataType::DE_INT8:
      Cast<T, int8_t>(input, output);
      break;
    case DataType::DE_UINT8:
      Cast<T, uint8_t>(input, output);
      break;
    case DataType::DE_INT16:
      Cast<T, int16_t>(input, output);
      break;
    case DataType::DE_UINT16:
      Cast<T, uint16_t>(input, output);
      break;
    case DataType::DE_INT32:
      Cast<T, int32_t>(input, output);
      break;
    case DataType::DE_UINT32:
      Cast<T, uint32_t>(input, output);
      break;
    case DataType::DE_INT64:
      Cast<T, int64_t>(input, output);
      break;
    case DataType::DE_UINT64:
      Cast<T, uint64_t>(input, output);
      break;
    case DataType::DE_FLOAT16:
      Cast<T, float16>(input, output);
      break;
    case DataType::DE_FLOAT32:
      Cast<T, float>(input, output);
      break;
    case DataType::DE_FLOAT64:
      Cast<T, double>(input, output);
      break;
    case DataType::DE_UNKNOWN:
      MS_LOG(ERROR) << "Unknown data type.";
      break;
  }
}

// Type cast operator: creates `output` with shape of `input` and dtype `data_type`,
// then copies the data across with element-wise static_cast conversion.
// @param input - source tensor (any known numeric dtype)
// @param output - receives the newly created, converted tensor
// @param data_type - destination dtype
// @return Status - the error code returned
Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
  // Force allocation of the output buffer before iterating over it.
  static_cast<void>((*output)->StartAddr());
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      CastFrom<bool>(input, output);
      break;
    case DataType::DE_INT8:
      CastFrom<int8_t>(input, output);
      break;
    case DataType::DE_UINT8:
      CastFrom<uint8_t>(input, output);
      break;
    case DataType::DE_INT16:
      CastFrom<int16_t>(input, output);
      break;
    case DataType::DE_UINT16:
      CastFrom<uint16_t>(input, output);
      break;
    case DataType::DE_INT32:
      CastFrom<int32_t>(input, output);
      break;
    case DataType::DE_UINT32:
      CastFrom<uint32_t>(input, output);
      break;
    case DataType::DE_INT64:
      CastFrom<int64_t>(input, output);
      break;
    case DataType::DE_UINT64:
      CastFrom<uint64_t>(input, output);
      break;
    case DataType::DE_FLOAT16:
      CastFrom<float16>(input, output);
      break;
    case DataType::DE_FLOAT32:
      CastFrom<float>(input, output);
      break;
    case DataType::DE_FLOAT64:
      CastFrom<double>(input, output);
      break;
    case DataType::DE_UNKNOWN:
      // sanity check, unreachable code.
      RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
  }
  return Status::OK();
}

// Converts a float32 tensor to a new float16 tensor of the same shape.
// @param input - source tensor; elements are read as float
// @param output - receives the newly created float16 tensor
// @return Status - the error code returned
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // initiate new tensor for type cast
  DataType new_type = DataType("float16");
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
  // Force allocation of the output buffer before iterating over it.
  static_cast<void>((*output)->StartAddr());

  auto in_itr = input->begin<float>();
  auto out_itr = (*output)->begin<float16>();
  auto out_end = (*output)->end<float16>();

  for (; out_itr != out_end; in_itr++, out_itr++) *out_itr = Eigen::half(*in_itr);

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore