You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tdt_plugin.cc 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/tdt/tdt_plugin.h"
  17. #include "common/utils.h"
  18. #include "utils/log_adapter.h"
  19. #include "dataset/engine/perf/profiling.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
  23. std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
  24. if (instance_ptr_ == nullptr) {
  25. instance_ptr_ = std::shared_ptr<TdtPlugin>(new TdtPlugin);
  26. }
  27. return instance_ptr_;
  28. }
  29. TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) {
  30. MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
  31. std::vector<DataItem> items;
  32. double start_time;
  33. auto ret = translate(ts_row, items);
  34. if (ret != SUCCESS) {
  35. MS_LOG(ERROR) << "TDT converting tensor failed!";
  36. return FAILED;
  37. }
  38. if (profiling) {
  39. start_time = ProfilingTime::GetCurMilliSecond();
  40. }
  41. if (tdt::TdtHostPushData(channel_name, items) != 0) {
  42. MS_LOG(ERROR) << "TDT pushing data failed!";
  43. return FAILED;
  44. }
  45. if (profiling) {
  46. double end_time = ProfilingTime::GetCurMilliSecond();
  47. time = (int32_t)(end_time - start_time);
  48. }
  49. return SUCCESS;
  50. }
  51. TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) {
  52. switch (d_type.value()) {
  53. case DataType::DE_BOOL:
  54. datatype = "bool";
  55. break;
  56. case DataType::DE_INT8:
  57. datatype = "int8";
  58. break;
  59. case DataType::DE_UINT8:
  60. datatype = "uint8";
  61. break;
  62. case DataType::DE_INT16:
  63. datatype = "int16";
  64. break;
  65. case DataType::DE_UINT16:
  66. datatype = "uint16";
  67. break;
  68. case DataType::DE_INT32:
  69. datatype = "int32";
  70. break;
  71. case DataType::DE_UINT32:
  72. datatype = "uint32";
  73. break;
  74. case DataType::DE_FLOAT16:
  75. datatype = "float16";
  76. break;
  77. case DataType::DE_FLOAT32:
  78. datatype = "float32";
  79. break;
  80. case DataType::DE_FLOAT64:
  81. datatype = "float64";
  82. break;
  83. case DataType::DE_INT64:
  84. datatype = "int64";
  85. break;
  86. case DataType::DE_UINT64:
  87. datatype = "uint64";
  88. break;
  89. default:
  90. return FAILED;
  91. }
  92. return SUCCESS;
  93. }
  94. TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &items) {
  95. if (ts_row.size() == 0) {
  96. MS_LOG(ERROR) << "TDT the size of row is zero.";
  97. return SUCCESS;
  98. }
  99. for (auto ts : ts_row) {
  100. std::string datatype;
  101. TdtStatus status = getTdtType(ts->type(), datatype);
  102. if (status != SUCCESS) {
  103. return status;
  104. }
  105. TensorShape tsShape = ts->shape();
  106. std::string dataShapes = "[";
  107. for (auto dim : tsShape.AsVector()) {
  108. (void)dataShapes.append(std::to_string(dim)).append(",");
  109. }
  110. dataShapes.pop_back();
  111. (void)dataShapes.append("]");
  112. DataItem data_item;
  113. data_item.dataType_ = tdt::TDT_TENSOR;
  114. data_item.tensorShape_ = dataShapes;
  115. data_item.tensorType_ = datatype;
  116. data_item.dataLen_ = ts->SizeInBytes();
  117. data_item.dataPtr_ =
  118. std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
  119. items.emplace_back(data_item);
  120. MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
  121. << ts->Size() << ".";
  122. }
  123. return SUCCESS;
  124. }
  125. } // namespace dataset
  126. } // namespace mindspore