You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

plugin_op.cc 4.4 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/kernels/plugin_op.h"
  17. #include "minddata/dataset/core/tensor.h"
  18. #include "minddata/dataset/plugin/plugin_loader.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. Status PluginOp::PluginToTensorRow(const std::vector<plugin::Tensor> &in_row, TensorRow *out_row) {
  22. CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
  23. out_row->reserve(in_row.size());
  24. for (const auto &tensor : in_row) {
  25. std::shared_ptr<Tensor> output;
  26. DataType tp = DataType(tensor.type_);
  27. CHECK_FAIL_RETURN_UNEXPECTED(tp.IsNumeric() && tp != DataType::DE_UNKNOWN, "Unsupported type: " + tensor.type_);
  28. RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(tensor.shape_), tp, tensor.buffer_.data(), &output));
  29. out_row->emplace_back(output);
  30. }
  31. return Status::OK();
  32. }
  33. Status PluginOp::TensorRowToPlugin(const TensorRow &in_row, std::vector<plugin::Tensor> *out_row) {
  34. CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
  35. out_row->resize(in_row.size());
  36. for (size_t ind = 0; ind < in_row.size(); ind++) {
  37. plugin::Tensor &tensor = (*out_row)[ind];
  38. if (in_row[ind]->type().IsNumeric()) {
  39. dsize_t buffer_size = in_row[ind]->SizeInBytes();
  40. tensor.buffer_.resize(buffer_size);
  41. if (buffer_size < SECUREC_MEM_MAX_LEN) {
  42. int ret_code = memcpy_s(tensor.buffer_.data(), tensor.buffer_.size(), in_row[ind]->GetBuffer(), buffer_size);
  43. CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into plugin tensor.");
  44. } else {
  45. auto ret_code = std::memcpy(tensor.buffer_.data(), in_row[ind]->GetBuffer(), buffer_size);
  46. CHECK_FAIL_RETURN_UNEXPECTED(ret_code == tensor.buffer_.data(), "Failed to copy data into plugin tensor.");
  47. }
  48. } else { // string tensor, for now, only tensor with 1 string is supported!
  49. CHECK_FAIL_RETURN_UNEXPECTED(in_row[ind]->shape().NumOfElements() == 1,
  50. "String tensor with more than 1 element is not yet supported.");
  51. // get the first and only string in this tensor
  52. std::string str1(*(in_row[ind]->begin<std::string_view>()));
  53. tensor.buffer_.resize(str1.size());
  54. auto ret_code = memcpy_s(tensor.buffer_.data(), tensor.buffer_.size(), str1.data(), str1.size());
  55. CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "memcpy_s failed when copying string tensor.");
  56. }
  57. tensor.shape_ = in_row[ind]->shape().AsVector();
  58. tensor.type_ = in_row[ind]->type().ToString();
  59. }
  60. return Status::OK();
  61. }
  62. Status PluginOp::Compute(const TensorRow &input, TensorRow *output) {
  63. // Compute should quit if init fails. Error code has already been logged, no need to repeat
  64. RETURN_IF_NOT_OK(init_code_);
  65. std::vector<plugin::Tensor> in_row, out_row;
  66. RETURN_IF_NOT_OK(TensorRowToPlugin(input, &in_row));
  67. plugin::Status rc = plugin_op_->Compute(&in_row, &out_row);
  68. CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
  69. RETURN_IF_NOT_OK(PluginToTensorRow(out_row, output));
  70. return Status::OK();
  71. }
  72. PluginOp::PluginOp(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
  73. : plugin_op_(nullptr), lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {
  74. init_code_ = Init();
  75. }
  76. Status PluginOp::Init() {
  77. plugin::PluginManagerBase *plugin = nullptr;
  78. RETURN_IF_NOT_OK(PluginLoader::GetInstance()->LoadPlugin(lib_path_, &plugin));
  79. // casting a void pointer to specific type
  80. plugin_op_ = dynamic_cast<plugin::TensorOp *>(plugin->GetModule(func_name_));
  81. RETURN_UNEXPECTED_IF_NULL(plugin_op_);
  82. plugin::Status rc = plugin_op_->ParseSerializedArgs(user_args_);
  83. CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
  84. return Status::OK();
  85. }
  86. } // namespace dataset
  87. } // namespace mindspore