|
|
@@ -113,22 +113,27 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) { |
|
|
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) { |
|
|
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)), |
|
|
|
|
|
|
|
|
const DataType &fill_type = fill_value->type(); |
|
|
|
|
|
const DataType &input_type = input->type(); |
|
|
|
|
|
const TensorShape &input_shape = input->shape(); |
|
|
|
|
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)), |
|
|
"Types do not match"); |
|
|
"Types do not match"); |
|
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); |
|
|
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); |
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> out; |
|
|
|
|
|
|
|
|
|
|
|
const DataType &to = input->type(); |
|
|
|
|
|
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to)); |
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> out, fill_output; |
|
|
|
|
|
|
|
|
std::shared_ptr<Tensor> fill_output; |
|
|
|
|
|
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); |
|
|
|
|
|
|
|
|
if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) { |
|
|
|
|
|
std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type)); |
|
|
|
|
|
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); |
|
|
|
|
|
} else { |
|
|
|
|
|
fill_output = fill_value; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type())); |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); |
|
|
|
|
|
|
|
|
switch (input->type().value()) { |
|
|
|
|
|
|
|
|
switch (input_type.value()) { |
|
|
case DataType::DE_BOOL: { |
|
|
case DataType::DE_BOOL: { |
|
|
bool value = 0; |
|
|
bool value = 0; |
|
|
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); |
|
|
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); |
|
|
@@ -206,10 +211,10 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output |
|
|
std::string_view fill_string_view; |
|
|
std::string_view fill_string_view; |
|
|
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); |
|
|
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); |
|
|
std::string fill_string = std::string(fill_string_view); |
|
|
std::string fill_string = std::string(fill_string_view); |
|
|
for (int i = 0; i < input->shape().NumOfElements(); i++) { |
|
|
|
|
|
|
|
|
for (int i = 0; i < input_shape.NumOfElements(); i++) { |
|
|
strings.emplace_back(fill_string); |
|
|
strings.emplace_back(fill_string); |
|
|
} |
|
|
} |
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape())); |
|
|
|
|
|
|
|
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
case DataType::DE_UNKNOWN: { |
|
|
case DataType::DE_UNKNOWN: { |
|
|
|