
batch_op.cc

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/batch_op.h"

#include <utility>
#include <iomanip>

#include "common/utils.h"
#include "dataset/core/pybind_support.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"

using float16 = Eigen::half;

namespace mindspore {
namespace dataset {

BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) {
  builder_batch_size_ = batch_size;
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_num_workers_ = cfg->num_parallel_workers();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status BatchOp::Builder::Build(std::shared_ptr<BatchOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_,
                                   builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_,
                                   builder_batch_map_func_, builder_pad_map_);
  return Status::OK();
}

Status BatchOp::Builder::SanityCheck() {
  std::string err;
  err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : "";
  err += builder_batch_size_ <= 0 ? "batch size <= 0\n" : "";
  err += builder_num_workers_ <= 0 ? "batch num_parallel_workers <= 0\n" : "";
  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}

BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
                 const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func,
                 std::map<std::string, std::pair<TensorShape, float>> pad_map)
    : ParallelOp(num_workers, op_queue_size),
      start_batch_size_(batch_size),
      drop_(drop),
      pad_(pad),
      pyfunc_column_names_(cols_to_map),
      batch_size_func_(batch_size_func),
      batch_map_func_(batch_map_func),
      pad_info_(pad_map) {
  worker_queues_.Init(num_workers, op_queue_size);
}
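
// Main driver: drains rows from the child operator, groups them into batches of the current
// batch size, and distributes each batch (plus EOE/EOF/quit control messages) to the worker
// queues in round-robin order.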
Status BatchOp::operator()() {
  Status rc = LaunchThreadsAndInitOp();
  // Synchronize with TaskManager
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(rc);
  int64_t epoch_num = 0, batch_num = 0, cnt = 0;
  TensorRow new_row;
  std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  for (const auto &t : new_row) {
    CHECK_FAIL_RETURN_UNEXPECTED(t->type().IsNumeric(),
                                 "[Batch ERROR] Batch does not support Tensor of type string yet.");
  }
  RETURN_IF_NOT_OK(DatasetOp::AssignColMapFromChild());  // must come after the first fetch above
  int32_t cur_batch_size = 0;
  RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0)));
  while (child_iterator_->eof_handled() == false) {
    while (new_row.empty() == false) {
      table->emplace_back(new_row);
      // if # of rows is enough to make 1 batch (1 batch is 1 buffer), send it to worker_queue
      if (table->size() == static_cast<size_t>(cur_batch_size)) {
        RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(
          std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num))));
        table = std::make_unique<TensorQTable>();
        RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num)));
      }
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }
    // Remainder logic: execute only when there is a remainder (table is non-empty) and drop_ is false
    if (drop_ == false && table->empty() == false) {
      RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(
        std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num))));
    }
    table = std::make_unique<TensorQTable>();  // this drops when drop == true
    // end of the current epoch, batch_num should start from 0 again
    batch_num = 0;
    epoch_num++;
    RETURN_IF_NOT_OK(
      worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE))));
    RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num)));
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  }  // end of eof_handled() == false
  RETURN_IF_NOT_OK(
    worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF))));
  // EOF received, send quit signal (an empty buffer) to all workers
  for (int32_t ind = 0; ind < num_workers_; ind++) {
    RETURN_IF_NOT_OK(
      worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit))));
  }
  return Status::OK();
}

void BatchOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <BatchOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << " [batch size: " << start_batch_size_ << "]\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nStart batch size: " << start_batch_size_ << "\nDrop remainder: " << (drop_ ? "yes" : "no") << "\n\n";
  }
}
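
// Pops batch_size rows from source_table and stacks them into a single batched row in dest_table,
// prepending a batch dimension to every column tensor.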
Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *source_table,
                          const std::unique_ptr<TensorQTable> *dest_table, size_t batch_size) {
  if ((*source_table)->size() < batch_size || (*source_table)->size() == 0) {
    RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Insufficient rows in source_table\n");
  }
  TensorRow row = std::move((*source_table)->front());
  (*source_table)->pop_front();
  if (batch_size == 1) {
    for (std::shared_ptr<Tensor> tensor : row) {
      RETURN_IF_NOT_OK(tensor->ExpandDim(0));
    }
    (*dest_table)->push_back(row);
  } else {  // batch_size > 1
    std::vector<TensorShape> row_shapes;
    TensorRow batched_row;
    for (size_t i = 0; i < row.size(); i++) {  // Handle the first row popped
      row_shapes.push_back(row[i]->shape());
      std::shared_ptr<Tensor> ts;
      RETURN_IF_NOT_OK(Tensor::CreateTensor(
        &ts, TensorImpl::kFlexible, row[i]->shape().PrependDim(static_cast<int64_t>(batch_size)), row[i]->type()));
      batched_row.emplace_back(ts);
      RETURN_IF_NOT_OK(batched_row[i]->InsertTensor(std::vector<dsize_t>(1, 0), row[i]));  // {j} = 0
    }
    for (size_t j = 1; j < batch_size; j++) {  // Handle the rest of the rows
      row = std::move((*source_table)->front());
      (*source_table)->pop_front();
      for (size_t i = 0; i < row.size(); i++) {
        if (row[i]->shape() == row_shapes[i]) {  // check the newly popped rows have the same dim as the first
          RETURN_IF_NOT_OK(batched_row[i]->InsertTensor(std::vector<dsize_t>(1, j), row[i]));
        } else {
          std::string column_name;
          for (auto itr : column_name_id_map_) {
            if (static_cast<size_t>(itr.second) == i) {
              column_name = itr.first;
              break;
            }
          }
          RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + column_name);
        }
      }
    }
    (*dest_table)->emplace_back(batched_row);
  }
  return Status::OK();
}
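
// Worker thread body: pops (table, CBatchInfo) pairs from this worker's queue, forwards EOE/EOF
// buffers downstream, builds a batched buffer for normal entries, and exits on the quit signal.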
Status BatchOp::WorkerEntry(int32_t workerId) {
  TaskManager::FindMe()->Post();
  std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair;
  RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair));
  while (table_pair.second.ctrl_ != batchCtrl::kQuit) {
    if (table_pair.second.ctrl_ == batchCtrl::kEOE) {
      RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) {
      RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
    } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) {
      std::unique_ptr<DataBuffer> db = nullptr;
      RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &db));
      RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::move(db)));
    }
    RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair));
  }
  return Status::OK();
}

Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
                                  std::unique_ptr<DataBuffer> *db) {
  RETURN_UNEXPECTED_IF_NULL(table_pair.first);
  if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair));  // pass it through pyfunc
  if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair));                           // do padding if needed
  (*db) = std::make_unique<DataBuffer>(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone);
  std::unique_ptr<TensorQTable> dest_table = std::make_unique<TensorQTable>();
  RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size()));
  (*db)->set_tensor_table(std::move(dest_table));
  return Status::OK();
}

Status BatchOp::LaunchThreadsAndInitOp() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1)));
  return Status::OK();
}

Status BatchOp::EofReceived(int32_t) { return Status::OK(); }

Status BatchOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
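
// Gathers the configured pyfunc columns into a TensorBatchTable, runs the Python batch map
// function on them, and writes the returned tensors back into the original table.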
Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) {
  TensorBatchTable input_table;
  input_table.reserve(pyfunc_column_names_.size());
  for (std::string col_name : pyfunc_column_names_) {
    if (column_name_id_map_.find(col_name) == column_name_id_map_.end()) {
      RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n");
    }
    TensorBatch tensor_batch;
    tensor_batch.reserve(table_pair->first->size());
    size_t col_idx = static_cast<size_t>(column_name_id_map_[col_name]);
    for (size_t row_idx = 0; row_idx < table_pair->first->size(); row_idx++) {
      tensor_batch.push_back(std::move(table_pair->first->at(row_idx)[col_idx]));
    }
    input_table.push_back(std::move(tensor_batch));
  }

  // Perform batch map
  TensorBatchTable output_table;
  RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second));

  // Write back to TensorQTable
  for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) {
    size_t col_idx = static_cast<size_t>(column_name_id_map_[pyfunc_column_names_[input_idx]]);
    size_t row_id = 0;
    for (TensorRow &row : *(table_pair->first)) {
      row[col_idx] = std::move(output_table[input_idx][row_id++]);
    }
  }
  return Status::OK();
}

Status BatchOp::GetBatchSize(int32_t *batch_size, CBatchInfo info) {
  if (batch_size_func_ != nullptr) {
    RETURN_IF_NOT_OK(InvokeBatchSizeFunc(batch_size, info));
  } else {
    (*batch_size) = start_batch_size_;
  }
  return Status::OK();
}
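
// Calls the user-supplied Python batch-size function under the GIL and checks that it returns
// a positive integer.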
Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) {
  {
    // Acquire Python GIL
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    try {
      py::object size = batch_size_func_(info);
      *batch_size = size.cast<int32_t>();
      if (*batch_size <= 0) {
        return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0");
      }
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    } catch (const py::cast_error &e) {
      return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0");
    }
  }
  return Status(StatusCode::kOK, "Batch size func call succeeded");
}
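
// Converts the input columns to NumPy arrays, invokes the Python batch map function under the
// GIL, and converts the returned tuple of lists back into Tensors.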
Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info) {
  {
    // Acquire Python GIL
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    try {
      // Prepare batch map callback parameters
      py::tuple input_args(input->size() + 1);
      for (size_t i = 0; i < input->size(); i++) {
        std::vector<py::array> np_batch;
        for (std::shared_ptr<Tensor> t : input->at(i)) {
          py::array np_array;
          RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array));
          np_batch.push_back(std::move(np_array));
        }
        input_args[i] = np_batch;
      }
      input_args[input->size()] = info;
      // Invoke batch map func
      py::object ret_py_obj = batch_map_func_(*input_args);
      // Parse batch map return value
      py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj);
      if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance<py::tuple>(ret_tuple)) {
        return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple");
      }
      for (size_t i = 0; i < ret_tuple.size(); i++) {
        TensorBatch output_batch;
        py::list output_list = py::cast<py::list>(ret_tuple[i]);
        for (size_t j = 0; j < output_list.size(); j++) {
          std::shared_ptr<Tensor> out;
          RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast<py::array>(output_list[j])));
          output_batch.push_back(std::move(out));
        }
        output->push_back(std::move(output_batch));
      }
    } catch (const py::error_already_set &e) {
      return Status(StatusCode::kPyFuncException, e.what());
    } catch (const py::cast_error &e) {
      return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple of lists of numpy arrays");
    }
  }
  return Status(StatusCode::kOK);
}
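
// Pads a single tensor to pad_shape, filling new elements with pad_val; if the shape already
// matches (or the tensor is a scalar), the source tensor is reused without copying.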
Status BatchOp::PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst,
                          const std::vector<dsize_t> &pad_shape, float pad_val) {
  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
    (*dst) = src;  // if no padding, copy the pointer
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
    RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
    auto tensor_type = src->type().value();
    if (pad_val == 0) {  // if pad with zero, don't care what type it is
      RETURN_IF_NOT_OK((*dst)->Zero());
    } else if (tensor_type == DataType::DE_INT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
    } else if (tensor_type == DataType::DE_BOOL) {
      RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
    } else if (tensor_type == DataType::DE_UINT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
    } else {
      RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
    }
    std::vector<dsize_t> cur_ind(src->Rank(), 0), src_s(src->Rank(), 1), dst_s(src->Rank(), 1);
    for (dsize_t i = src->Rank() - 2; i >= 0; i--) {
      src_s[i] = src->shape()[i + 1] * src_s[i + 1];
      dst_s[i] = pad_shape[i + 1] * dst_s[i + 1];
    }
    RETURN_IF_NOT_OK(PadHelper(src, *dst, cur_ind, src_s, dst_s, 0));
  }
  return Status::OK();
}
}
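
// Pads every column listed in pad_info_ (or all columns when pad_info_ is empty) to a common
// shape, using the per-dimension maximum of the current batch wherever a dimension is
// unspecified (-1).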
Status BatchOp::PadColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) {
  RETURN_UNEXPECTED_IF_NULL(table_pair);  // placeholder for now, might need this in the future
  CHECK_FAIL_RETURN_UNEXPECTED(table_pair->first->front().size() == column_name_id_map_.size(),
                               "col_name_map mismatch");
  std::vector<float> pad_vals(column_name_id_map_.size(), 0);  // value to pad each column's tensor with, default 0
  std::set<int32_t> pad_cols;
  // padded_shape provided by user, maximum shapes of current batch of tensors
  std::vector<std::vector<dsize_t>> pad_shapes(column_name_id_map_.size()), max_shapes(column_name_id_map_.size());
  RETURN_IF_NOT_OK(UnpackPadInfo(&pad_cols, &pad_vals, &pad_shapes));

  // init each shape in max_shape to {-1,-1...}; init each unspecified shape in pad_shape to -1 as well
  for (size_t col_id : pad_cols) {
    max_shapes[col_id] = std::vector<dsize_t>(table_pair->first->front()[col_id]->Rank(), -1);
    if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id];  // fill pad shape with -1
    CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape");
  }

  // calculate maximum shape for each column that needs to be padded
  for (const TensorRow &row : *(table_pair->first)) {  // iterate over each row in a batch
    for (size_t col_id : pad_cols) {                   // iterate over each tensor in a row
      CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(),
                                   "Tensors to be padded together need to have the same rank");
      for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) {  // pick the largest number in each dimension
        max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]);
      }
    }
  }

  // if user sets a dimension to -1 (None in python), use the max value for current dimension
  for (size_t col_id : pad_cols) {
    for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) {
      if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim];
    }
  }

  // call pad on each tensor that needs to be padded
  for (TensorRow &row : *(table_pair->first)) {
    for (size_t col_id : pad_cols) {
      std::shared_ptr<Tensor> pad_tensor;
      RETURN_IF_NOT_OK(PadTensor(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id]));
      row[col_id] = pad_tensor;
    }
  }
  return Status::OK();
}
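
// Expands pad_info_ into per-column pad values and pad shapes; when pad_info_ is empty, every
// column is marked for automatic padding.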
Status BatchOp::UnpackPadInfo(std::set<int32_t> *pad_cols, std::vector<float> *pad_vals,
                              std::vector<std::vector<dsize_t>> *pad_shapes) {
  if (pad_info_.empty()) {  // if pad_info_ is empty, pad every column automatically
    for (dsize_t col_id = 0; col_id < column_name_id_map_.size(); col_id++) {
      pad_cols->insert(col_id);
    }
  } else {
    for (auto p : pad_info_) {
      CHECK_FAIL_RETURN_UNEXPECTED(column_name_id_map_.find(p.first) != column_name_id_map_.end(),
                                   "no column exists with name:" + p.first);
      dsize_t col_id = static_cast<dsize_t>(column_name_id_map_[p.first]);
      CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound");
      pad_cols->insert(col_id);
      (*pad_vals)[col_id] = p.second.second;              // set pad values
      (*pad_shapes)[col_id] = p.second.first.AsVector();  // empty vector if shape is unknown
    }
  }
  return Status::OK();
}
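
// Recursively copies src into the (already filled) dst tensor, walking the outer dimensions and
// copying contiguous runs along the innermost dimension with memcpy_s.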
Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> dst, std::vector<dsize_t> cur_ind,
                          const std::vector<dsize_t> &src_s, const std::vector<dsize_t> &dst_s, size_t cur_dim) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    uint8_t type_size = src->type().SizeInBytes();
    size_t len = std::min(src->shape()[cur_dim], dst->shape()[cur_dim]) * type_size;
    dsize_t src_flat_ind = 0, dst_flat_ind = 0;
    for (size_t i = 0; i < src->Rank(); i++) {
      src_flat_ind += src_s[i] * cur_ind[i];
      dst_flat_ind += dst_s[i] * cur_ind[i];
    }
    unsigned char *src_addr = src->GetMutableBuffer() + src_flat_ind * type_size;
    unsigned char *dst_addr = dst->GetMutableBuffer() + dst_flat_ind * type_size;
    CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadHelper(src, dst, cur_ind, src_s, dst_s, cur_dim + 1));
    }
  }
  return Status::OK();
}

// Visitor accept method for NodePass
Status BatchOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
}
}  // namespace dataset
}  // namespace mindspore
  449. } // namespace mindspore