You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_op.cc 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/dataset_op.h"
  17. #include <iomanip>
  18. #include <iostream>
  19. #include <memory>
  20. #include <regex>
  21. #include <utility>
  22. #include <string>
  23. #include <algorithm>
  24. #include "dataset/engine/execution_tree.h"
  25. #include "dataset/engine/datasetops/device_queue_op.h"
  26. #include "dataset/engine/datasetops/source/sampler/sampler.h"
  27. #include "dataset/engine/data_buffer.h"
  28. #include "dataset/engine/db_connector.h"
  29. #include "dataset/engine/opt/pass.h"
  30. #include "utils/system/crc32c.h"
  31. #include "utils/log_adapter.h"
  32. namespace mindspore {
  33. namespace dataset {
  34. // Constructor
  35. DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
  36. : oc_queue_size_(op_connector_size),
  37. sampler_(sampler),
  38. operator_id_(kInvalidOperatorId),
  39. tree_(nullptr),
  40. state_(OpState::kDeOpIdle),
  41. op_ctrl_flags_(kDeOpNone),
  42. out_connector_(nullptr) {
  43. // The operator starts out with an invalid operator id. The only way to
  44. // get it out of invalid state is to assign the operator to an execution tree.
  45. }
  46. // Adds a operator to become our child.
  47. Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
  48. if (std::dynamic_pointer_cast<DeviceQueueOp>(child) != nullptr) {
  49. std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node");
  50. RETURN_STATUS_UNEXPECTED(err_msg);
  51. }
  52. if (operator_id_ == kInvalidOperatorId) {
  53. std::string err_msg(
  54. "Cannot add child node. Tree node connections can only"
  55. "be made if the node belongs to a tree.");
  56. RETURN_STATUS_UNEXPECTED(err_msg);
  57. }
  58. // disallow relationships with other trees
  59. if (tree_ != child->tree_) {
  60. std::string err_msg(
  61. "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree.");
  62. RETURN_STATUS_UNEXPECTED(err_msg);
  63. }
  64. child_.push_back(child);
  65. child->AddParent(this);
  66. return Status::OK();
  67. }
  68. Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
  69. if (operator_id_ == kInvalidOperatorId) {
  70. std::string err_msg(
  71. "Cannot remove child node. Tree node connections can only"
  72. "be made if the node belongs to a tree.");
  73. RETURN_STATUS_UNEXPECTED(err_msg);
  74. }
  75. // disallow relationships with other trees
  76. if (tree_ != child->tree_) {
  77. std::string err_msg(
  78. "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
  79. RETURN_STATUS_UNEXPECTED(err_msg);
  80. }
  81. child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
  82. child->RemoveParent(this);
  83. return Status::OK();
  84. }
  85. Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
  86. for (auto &prev_parent : this->parent_) {
  87. RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
  88. RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
  89. }
  90. RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
  91. if (tree_->root()->id() == this->id()) {
  92. tree_->AssignRoot(to_add);
  93. }
  94. return Status::OK();
  95. }
  96. // Adds a parent operator to this operator
  97. void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
  98. // Removes a parent operator from this operator
  99. void DatasetOp::RemoveParent(const DatasetOp *parent) {
  100. parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
  101. }
  102. // Removes this node from the tree and connects it's parent/child together
  103. Status DatasetOp::Remove() {
  104. if (parent_.size() > 1) {
  105. std::string err_msg("No support for op removal if the operator has more than one parent");
  106. RETURN_STATUS_UNEXPECTED(err_msg);
  107. }
  108. if (child_.size() > 1) {
  109. std::string err_msg("No support for op removal if the operator has more than one child");
  110. RETURN_STATUS_UNEXPECTED(err_msg);
  111. }
  112. // Scenario's when removing node B:
  113. // A -> B -> C
  114. // A -> B
  115. // B -> C
  116. //
  117. // If we remove B, then first take our child A and update it's parent to be C
  118. // It's possible the parent is null if we are the root node being removed.
  119. if (!child_.empty()) {
  120. // If we have a parent, then assign chlid's parent to point to our parent.
  121. if (!parent_.empty()) {
  122. child_[0]->parent_[0] = parent_[0];
  123. } else {
  124. // We don't have a parent, so we are the root node being removed.
  125. // clear the parent list of our child so that it becomes the new root.
  126. child_[0]->parent_.clear();
  127. tree_->AssignRoot(child_[0]);
  128. }
  129. }
  130. // Next, if we had a parent, then set it's child to be our child.
  131. if (!parent_.empty()) {
  132. // if we have a child, then set our parent to point to it
  133. if (!child_.empty()) {
  134. parent_[0]->child_[0] = child_[0];
  135. } else {
  136. // We don't have a child, so clear the child list of the current
  137. // parent because it will be empty once we are removed.
  138. parent_[0]->child_.clear();
  139. }
  140. }
  141. return Status::OK();
  142. }
  143. // Getter function to get a shared pointer to our childAdds a operator to become our child.
  144. std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
  145. MS_ASSERT(child_index < static_cast<int>(child_.size()));
  146. // Return a shared pointer
  147. return child_[child_index];
  148. }
  149. // Creates the connector within this operator
  150. void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
  151. MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
  152. << ". Consumer: " << num_consumers << ".";
  153. if (oc_queue_size_ > 0) {
  154. out_connector_ = std::make_unique<DbConnector>(num_producers, // The number of producers
  155. num_consumers, // Only one consumer (the training App)
  156. oc_queue_size_);
  157. } else {
  158. // Some op's may choose not to have an output connector
  159. MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
  160. out_connector_ = nullptr;
  161. }
  162. }
  163. // A print method typically used for debugging. showAll of true will recursively descend to child prints
  164. void DatasetOp::Print(std::ostream &out, bool show_all) const {
  165. // When show_all is false, we display a 1 liner piece of text for the op.
  166. // When show_all is true, we display more detailed output for the op.
  167. // Derived printers should show their own header info, then call base class printer, followed by
  168. // derived-specific items.
  169. // For now, the base class doesn't have any summary info to show so it's a no-op in that case.
  170. if (show_all) {
  171. // The detailed display will show common base class info of the op. Allow the derived class to print
  172. // it's own id and name though as the first line.
  173. out << "\nNumber of children : " << child_.size();
  174. for (size_t i = 0; i < child_.size(); i++) {
  175. out << "\n Child[" << i << "] id: " << child_[i]->id();
  176. }
  177. out << "\nNumber of parents : " << parent_.size();
  178. for (size_t i = 0; i < parent_.size(); i++) {
  179. out << "\n Parent[" << i << "] id: " << parent_[i]->id();
  180. }
  181. out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
  182. << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
  183. if (sampler_) {
  184. sampler_->Print(out, show_all);
  185. }
  186. }
  187. }
  188. // Gets the next buffer from the given child
  189. Status DatasetOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  190. #if defined(_WIN32) || defined(_WIN64)
  191. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), p_buffer, retry_if_eoe));
  192. #else
  193. std::unique_ptr<DataBuffer> next_buff;
  194. // pop is a blocked call and will throw an interruption if the whole group shuts down.
  195. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), &next_buff, retry_if_eoe));
  196. *p_buffer = std::move(next_buff);
  197. #endif
  198. return Status::OK();
  199. }
  200. // Gets the next buffer from the given child . This function also has built-in eoe and eof
  201. // message handling so that child classes don't have to manually code pass-through logic when
  202. // those messages are received.
  203. Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, int32_t child_index) {
  204. if (child_.size() == 0) {
  205. return this->GetNextBuffer(p_buffer, worker_id);
  206. }
  207. CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index));
  208. std::shared_ptr<DatasetOp> child = child_[child_index];
  209. std::unique_ptr<DataBuffer> buf;
  210. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  211. // Loop until non EOE is received
  212. while (buf->eoe()) {
  213. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  214. if (state_ == OpState::kDeOpIdle) {
  215. *p_buffer = std::move(buf);
  216. return Status::OK();
  217. }
  218. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  219. }
  220. // Check if the last buf is next eof
  221. if (buf->eof()) {
  222. RETURN_IF_NOT_OK(EofReceived(worker_id));
  223. }
  224. *p_buffer = std::move(buf);
  225. return Status::OK();
  226. }
  227. // Performs handling for when an eoe message is received.
  228. // The base class implementation simply flows the eoe message to output. Derived classes
  229. // may override if they need to perform special eoe handling.
  230. Status DatasetOp::EoeReceived(int32_t worker_id) {
  231. std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  232. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eoe_buffer)));
  233. }
  234. // Performs handling for when an eof message is received.
  235. // The base class implementation simply flows the eof message to output. Derived classes
  236. // may override if they need to perform special eof handling.
  237. Status DatasetOp::EofReceived(int32_t worker_id) {
  238. std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  239. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
  240. }
  241. // During tree prepare phase, operators may have specific pre-operations to perform depending on
  242. // their role.
  243. Status DatasetOp::PrepareNodePreAction() {
  244. if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
  245. return Status::OK();
  246. }
  247. // During tree prepare phase, operators may have specific post-operations to perform depending on
  248. // their role.
  249. Status DatasetOp::PrepareNodePostAction() {
  250. // If this op does not have any children and it is in a repeat path of the tree...
  251. if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
  252. // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
  253. // above us will consume them.
  254. tree_->AddToEOEOpStack(shared_from_this());
  255. }
  256. // Creating Connector object for each op.
  257. // The consumer of the root node is assumed to be one thread.
  258. // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
  259. if (parent_.empty()) {
  260. this->CreateConnector(num_producers(), 1);
  261. } else {
  262. this->CreateConnector(num_producers(), parent_[0]->num_consumers());
  263. }
  264. if (out_connector_) {
  265. RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
  266. }
  267. RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
  268. // Generate the column name map for the current op.
  269. RETURN_IF_NOT_OK(this->ComputeColMap());
  270. return Status::OK();
  271. }
  272. // Getter function. Base class does not have any special flags setting.
  273. uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; }
  274. // Derived classes may implement the reset function if the operator is stateful and needs
  275. // specific reset handling that is not contained in this common code version of the reset.
  276. Status DatasetOp::Reset() {
  277. state_ = OpState::kDeOpRunning;
  278. return Status::OK();
  279. }
  280. // gives a string output for the column map for handy debug printing
  281. std::string DatasetOp::ColumnNameMapAsString() const {
  282. std::string outStr = "Column name id map: ";
  283. for (auto &it : column_name_id_map_) {
  284. outStr += (" " + it.first + ":" + std::to_string(it.second));
  285. }
  286. return outStr;
  287. }
  288. // Computing the assignment of the column name map.
  289. // This just inherits the column map from its first child, can only be used if the number of children is 1.
  290. // Operations changing the column map must overwrite this function.
  291. Status DatasetOp::ComputeColMap() {
  292. if (child_.size() > 1) {
  293. RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators.");
  294. }
  295. if (column_name_id_map_.empty()) {
  296. column_name_id_map_ = child_[0]->column_name_id_map();
  297. if (column_name_id_map_.empty()) {
  298. RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  299. }
  300. MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString();
  301. } else {
  302. MS_LOG(WARNING) << "Column name map is already set!";
  303. }
  304. return Status::OK();
  305. }
  306. Status DatasetOp::PreAccept(NodePass *p, bool *modified) {
  307. // DatasetOp is the base class of visitor target pre-visit.
  308. // This method will only be called if its derived class does not implement one.
  309. return p->PreRunOnNode(shared_from_this(), modified);
  310. }
  311. Status DatasetOp::Accept(NodePass *p, bool *modified) {
  312. // DatasetOp is the base class of visitor target.
  313. // This method will only be called if its derived class does not implement one.
  314. return p->RunOnNode(shared_from_this(), modified);
  315. }
  316. // A helper function with some common code that leaf nodes can use during
  317. // prepare phase for checking if they need to assign a sampler to the cache.
  318. Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
  319. // If we are a descendant under a cache op and we have a sampler, then save this sampler
  320. // to a stack so that the cache can pick it up during it's processing above us.
  321. if (sampler_) {
  322. if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
  323. // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
  324. // useless to a random data op. It was only being used as a temporary holding until the cache can
  325. // be created
  326. tree_->AddToSamplerStack(sampler_);
  327. MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
  328. } else if (!random_access_op) {
  329. // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
  330. // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
  331. RETURN_STATUS_UNEXPECTED(
  332. "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
  333. }
  334. }
  335. if (!random_access_op) {
  336. // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
  337. // we can remove it now from the base.
  338. sampler_.reset();
  339. }
  340. return Status::OK();
  341. }
  342. uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
  343. std::stringstream ss;
  344. op->tree_->Print(ss, op);
  345. std::string ss_str = ss.str();
  346. // Filter out the Operator control flags field when generating the check sum
  347. ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
  348. // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline
  349. ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
  350. ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
  351. // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same
  352. // cache_client later. So we filter out these two fields to allow cache sharing.
  353. ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
  354. ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
  355. uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
  356. return cache_crc;
  357. }
  358. } // namespace dataset
  359. } // namespace mindspore