You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tdt_plugin.cc 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/tdt/tdt_plugin.h"
  17. #include "common/utils.h"
  18. #include "utils/log_adapter.h"
  19. namespace mindspore {
  20. namespace dataset {
  21. static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
  22. std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
  23. if (instance_ptr_ == nullptr) {
  24. instance_ptr_ = std::shared_ptr<TdtPlugin>(new TdtPlugin);
  25. }
  26. return instance_ptr_;
  27. }
  28. TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name) {
  29. MS_LOG(INFO) << "TDT channel name is " << channel_name << ".";
  30. std::vector<DataItem> items;
  31. auto ret = translate(ts_row, items);
  32. if (ret != SUCCESS) {
  33. MS_LOG(ERROR) << "TDT converting tensor failed!";
  34. return FAILED;
  35. }
  36. if (tdt::TdtHostPushData(channel_name, items) != 0) {
  37. MS_LOG(ERROR) << "TDT pushing data failed!";
  38. return FAILED;
  39. }
  40. return SUCCESS;
  41. }
  42. TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) {
  43. switch (d_type.value()) {
  44. case DataType::DE_BOOL:
  45. datatype = "bool";
  46. break;
  47. case DataType::DE_INT8:
  48. datatype = "int8";
  49. break;
  50. case DataType::DE_UINT8:
  51. datatype = "uint8";
  52. break;
  53. case DataType::DE_INT16:
  54. datatype = "int16";
  55. break;
  56. case DataType::DE_UINT16:
  57. datatype = "uint16";
  58. break;
  59. case DataType::DE_INT32:
  60. datatype = "int32";
  61. break;
  62. case DataType::DE_UINT32:
  63. datatype = "uint32";
  64. break;
  65. case DataType::DE_FLOAT16:
  66. datatype = "float16";
  67. break;
  68. case DataType::DE_FLOAT32:
  69. datatype = "float32";
  70. break;
  71. case DataType::DE_FLOAT64:
  72. datatype = "float64";
  73. break;
  74. case DataType::DE_INT64:
  75. datatype = "int64";
  76. break;
  77. case DataType::DE_UINT64:
  78. datatype = "uint64";
  79. break;
  80. default:
  81. return FAILED;
  82. }
  83. return SUCCESS;
  84. }
  85. TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &items) {
  86. if (ts_row.size() == 0) {
  87. MS_LOG(ERROR) << "TDT the size of row is zero.";
  88. return SUCCESS;
  89. }
  90. for (auto ts : ts_row) {
  91. std::string datatype;
  92. TdtStatus status = getTdtType(ts->type(), datatype);
  93. if (status != SUCCESS) {
  94. return status;
  95. }
  96. TensorShape tsShape = ts->shape();
  97. std::string dataShapes = "[";
  98. for (auto dim : tsShape.AsVector()) {
  99. (void)dataShapes.append(std::to_string(dim)).append(",");
  100. }
  101. dataShapes.pop_back();
  102. (void)dataShapes.append("]");
  103. DataItem data_item;
  104. data_item.dataType_ = tdt::TDT_TENSOR;
  105. data_item.tensorShape_ = dataShapes;
  106. data_item.tensorType_ = datatype;
  107. data_item.dataLen_ = ts->SizeInBytes();
  108. data_item.dataPtr_ = std::shared_ptr<void>(reinterpret_cast<void *>(ts->StartAddr()), [](void *elem) {});
  109. items.emplace_back(data_item);
  110. MS_LOG(INFO) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
  111. << ts->Size() << ".";
  112. }
  113. return SUCCESS;
  114. }
  115. } // namespace dataset
  116. } // namespace mindspore