|
|
|
@@ -53,5 +53,29 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS |
|
|
|
return std::make_shared<tensor::DETensor>(std::move(de_output)); |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<dataset::Tensor> Execute::operator()(std::shared_ptr<dataset::Tensor> input) { |
|
|
|
// Build the op |
|
|
|
if (op_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input TensorOperation is not valid"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (input == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input Tensor is not valid"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// will add validate params once API is set |
|
|
|
std::shared_ptr<TensorOp> transform = op_->Build(); |
|
|
|
std::shared_ptr<Tensor> de_output; |
|
|
|
Status rc = transform->Compute(input, &de_output); |
|
|
|
|
|
|
|
if (rc.IsError()) { |
|
|
|
// execution failed |
|
|
|
MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return de_output; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace dataset |
|
|
|
} // namespace mindspore |