You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

serdes.cc 25 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include <stack>
  18. #include <iomanip>
  19. #include "minddata/dataset/engine/serdes.h"
  20. #include "minddata/dataset/core/pybind_support.h"
  21. #include "utils/file_utils.h"
  22. #include "include/common/utils/utils.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
  26. Serdes::func_ptr_ = Serdes::InitializeFuncPtr();
  27. Status Serdes::SaveToJSON(std::shared_ptr<DatasetNode> node, const std::string &filename, nlohmann::json *out_json) {
  28. RETURN_UNEXPECTED_IF_NULL(node);
  29. RETURN_UNEXPECTED_IF_NULL(out_json);
  30. // If an optimized IR Tree is sent (use-case for MD AutoTune), ignore Top and EpochCtrl nodes
  31. if (node->Name() == "Top" || node->Name() == "EpochCtrl") {
  32. CHECK_FAIL_RETURN_UNEXPECTED(
  33. node->Children().size() == 1,
  34. "Expected " + node->Name() + " to have exactly 1 child but it has " + std::to_string(node->Children().size()));
  35. return SaveToJSON(node->Children()[0], filename, out_json);
  36. }
  37. // Dump attributes of current node to json string
  38. nlohmann::json args;
  39. RETURN_IF_NOT_OK(node->to_json(&args));
  40. args["op_type"] = node->Name();
  41. // If the current node isn't leaf node, visit all its children and get all attributes
  42. std::vector<nlohmann::json> children_pipeline;
  43. if (!node->IsLeaf()) {
  44. for (auto child : node->Children()) {
  45. nlohmann::json child_args;
  46. RETURN_IF_NOT_OK(SaveToJSON(child, "", &child_args));
  47. children_pipeline.push_back(child_args);
  48. }
  49. }
  50. args["children"] = children_pipeline;
  51. // Save json string into file if filename is given.
  52. if (!filename.empty()) {
  53. RETURN_IF_NOT_OK(SaveJSONToFile(args, filename));
  54. }
  55. *out_json = args;
  56. return Status::OK();
  57. }
  58. Status Serdes::SaveJSONToFile(const nlohmann::json &json_string, const std::string &file_name, bool pretty) {
  59. constexpr int field_width = 4;
  60. try {
  61. std::optional<std::string> dir = "";
  62. std::optional<std::string> local_file_name = "";
  63. FileUtils::SplitDirAndFileName(file_name, &dir, &local_file_name);
  64. if (!dir.has_value()) {
  65. dir = ".";
  66. }
  67. auto realpath = FileUtils::GetRealPath(dir.value().c_str());
  68. if (!realpath.has_value()) {
  69. MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file_name;
  70. RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file_name);
  71. }
  72. std::optional<std::string> whole_path = "";
  73. FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
  74. std::ofstream file(whole_path.value());
  75. if (pretty) {
  76. file << std::setw(field_width);
  77. }
  78. file << json_string << std::endl;
  79. file.close();
  80. ChangeFileMode(whole_path.value(), S_IRUSR | S_IWUSR);
  81. } catch (const std::exception &err) {
  82. RETURN_STATUS_UNEXPECTED("Invalid data, failed to save json string into file: " + file_name +
  83. ", error message: " + err.what());
  84. }
  85. return Status::OK();
  86. }
  87. Status Serdes::Deserialize(const std::string &json_filepath, std::shared_ptr<DatasetNode> *ds) {
  88. nlohmann::json json_obj;
  89. CHECK_FAIL_RETURN_UNEXPECTED(json_filepath.size() != 0, "Json path is null");
  90. std::ifstream json_in(json_filepath);
  91. CHECK_FAIL_RETURN_UNEXPECTED(json_in, "Invalid file, failed to open json file: " + json_filepath);
  92. try {
  93. json_in >> json_obj;
  94. } catch (const std::exception &e) {
  95. json_in.close();
  96. return Status(StatusCode::kMDSyntaxError,
  97. "Invalid file, failed to parse json file: " + json_filepath + ", error message: " + e.what());
  98. }
  99. json_in.close();
  100. // Handle config generated by dataset autotune
  101. if (json_obj.find("pipeline") != json_obj.end()) {
  102. json_obj = json_obj["pipeline"];
  103. }
  104. RETURN_IF_NOT_OK(ConstructPipeline(json_obj, ds));
  105. return Status::OK();
  106. }
  107. Status Serdes::ConstructPipeline(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
  108. CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children") != json_obj.end(), "Failed to find children");
  109. std::shared_ptr<DatasetNode> child_ds;
  110. if (json_obj["children"].size() == 0) {
  111. // If the JSON object has no child, then this node is a leaf node. Call create node to construct the corresponding
  112. // leaf node
  113. RETURN_IF_NOT_OK(CreateNode(nullptr, json_obj, ds));
  114. } else if (json_obj["children"].size() == 1) {
  115. // This node only has one child, construct the sub-tree under it first, and then call create node to construct the
  116. // corresponding node
  117. RETURN_IF_NOT_OK(ConstructPipeline(json_obj["children"][0], &child_ds));
  118. RETURN_IF_NOT_OK(CreateNode(child_ds, json_obj, ds));
  119. } else {
  120. std::vector<std::shared_ptr<DatasetNode>> datasets;
  121. for (const auto &child_json_obj : json_obj["children"]) {
  122. RETURN_IF_NOT_OK(ConstructPipeline(child_json_obj, &child_ds));
  123. datasets.push_back(child_ds);
  124. }
  125. if (json_obj["op_type"] == "Zip") {
  126. CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should zip more than 1 dataset");
  127. RETURN_IF_NOT_OK(ZipNode::from_json(datasets, ds));
  128. } else if (json_obj["op_type"] == "Concat") {
  129. CHECK_FAIL_RETURN_UNEXPECTED(datasets.size() > 1, "Should concat more than 1 dataset");
  130. RETURN_IF_NOT_OK(ConcatNode::from_json(json_obj, datasets, ds));
  131. } else {
  132. return Status(StatusCode::kMDUnexpectedError,
  133. "Invalid data, unsupported operation type: " + std::string(json_obj["op_type"]));
  134. }
  135. }
  136. return Status::OK();
  137. }
  138. Status Serdes::CreateNode(const std::shared_ptr<DatasetNode> &child_ds, nlohmann::json json_obj,
  139. std::shared_ptr<DatasetNode> *ds) {
  140. CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("op_type") != json_obj.end(), "Failed to find op_type in json.");
  141. std::string op_type = json_obj["op_type"];
  142. if (child_ds == nullptr) {
  143. // if dataset doesn't have any child, then create a source dataset IR. e.g., ImageFolderNode, CocoNode
  144. RETURN_IF_NOT_OK(CreateDatasetNode(json_obj, op_type, ds));
  145. } else {
  146. // if the dataset has at least one child, then create an operation dataset IR, e.g., BatchNode, MapNode
  147. RETURN_IF_NOT_OK(CreateDatasetOperationNode(child_ds, json_obj, op_type, ds));
  148. }
  149. return Status::OK();
  150. }
  151. Status Serdes::CreateDatasetNode(const nlohmann::json &json_obj, const std::string &op_type,
  152. std::shared_ptr<DatasetNode> *ds) {
  153. if (op_type == kAlbumNode) {
  154. RETURN_IF_NOT_OK(AlbumNode::from_json(json_obj, ds));
  155. } else if (op_type == kCelebANode) {
  156. RETURN_IF_NOT_OK(CelebANode::from_json(json_obj, ds));
  157. } else if (op_type == kCifar10Node) {
  158. RETURN_IF_NOT_OK(Cifar10Node::from_json(json_obj, ds));
  159. } else if (op_type == kCifar100Node) {
  160. RETURN_IF_NOT_OK(Cifar100Node::from_json(json_obj, ds));
  161. } else if (op_type == kCLUENode) {
  162. RETURN_IF_NOT_OK(CLUENode::from_json(json_obj, ds));
  163. } else if (op_type == kCocoNode) {
  164. RETURN_IF_NOT_OK(CocoNode::from_json(json_obj, ds));
  165. } else if (op_type == kCSVNode) {
  166. RETURN_IF_NOT_OK(CSVNode::from_json(json_obj, ds));
  167. } else if (op_type == kFlickrNode) {
  168. RETURN_IF_NOT_OK(FlickrNode::from_json(json_obj, ds));
  169. } else if (op_type == kImageFolderNode) {
  170. RETURN_IF_NOT_OK(ImageFolderNode::from_json(json_obj, ds));
  171. } else if (op_type == kManifestNode) {
  172. RETURN_IF_NOT_OK(ManifestNode::from_json(json_obj, ds));
  173. } else if (op_type == kMnistNode) {
  174. RETURN_IF_NOT_OK(MnistNode::from_json(json_obj, ds));
  175. } else if (op_type == kTextFileNode) {
  176. RETURN_IF_NOT_OK(TextFileNode::from_json(json_obj, ds));
  177. } else if (op_type == kTFRecordNode) {
  178. RETURN_IF_NOT_OK(TFRecordNode::from_json(json_obj, ds));
  179. } else if (op_type == kVOCNode) {
  180. RETURN_IF_NOT_OK(VOCNode::from_json(json_obj, ds));
  181. } else {
  182. return Status(StatusCode::kMDUnexpectedError, "Invalid data, unsupported operation type: " + op_type);
  183. }
  184. return Status::OK();
  185. }
  186. Status Serdes::CreateDatasetOperationNode(const std::shared_ptr<DatasetNode> &ds, const nlohmann::json &json_obj,
  187. const std::string &op_type, std::shared_ptr<DatasetNode> *result) {
  188. if (op_type == kBatchNode) {
  189. RETURN_IF_NOT_OK(BatchNode::from_json(json_obj, ds, result));
  190. } else if (op_type == kMapNode) {
  191. RETURN_IF_NOT_OK(MapNode::from_json(json_obj, ds, result));
  192. } else if (op_type == kProjectNode) {
  193. RETURN_IF_NOT_OK(ProjectNode::from_json(json_obj, ds, result));
  194. } else if (op_type == kRenameNode) {
  195. RETURN_IF_NOT_OK(RenameNode::from_json(json_obj, ds, result));
  196. } else if (op_type == kRepeatNode) {
  197. RETURN_IF_NOT_OK(RepeatNode::from_json(json_obj, ds, result));
  198. } else if (op_type == kShuffleNode) {
  199. RETURN_IF_NOT_OK(ShuffleNode::from_json(json_obj, ds, result));
  200. } else if (op_type == kSkipNode) {
  201. RETURN_IF_NOT_OK(SkipNode::from_json(json_obj, ds, result));
  202. } else if (op_type == kTransferNode) {
  203. RETURN_IF_NOT_OK(TransferNode::from_json(json_obj, ds, result));
  204. } else if (op_type == kTakeNode) {
  205. RETURN_IF_NOT_OK(TakeNode::from_json(json_obj, ds, result));
  206. } else {
  207. return Status(StatusCode::kMDUnexpectedError, "Invalid data, unsupported operation type: " + op_type);
  208. }
  209. return Status::OK();
  210. }
  211. Status Serdes::ConstructSampler(nlohmann::json json_obj, std::shared_ptr<SamplerObj> *sampler) {
  212. if (json_obj["sampler_name"] == "SkipFirstEpochSampler") {
  213. RETURN_IF_NOT_OK(SkipFirstEpochSamplerObj::from_json(json_obj, sampler));
  214. return Status::OK();
  215. }
  216. CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
  217. CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler_name") != json_obj.end(), "Failed to find sampler_name");
  218. int64_t num_samples = json_obj["num_samples"];
  219. std::string sampler_name = json_obj["sampler_name"];
  220. if (sampler_name == "DistributedSampler") {
  221. RETURN_IF_NOT_OK(DistributedSamplerObj::from_json(json_obj, num_samples, sampler));
  222. } else if (sampler_name == "PKSampler") {
  223. RETURN_IF_NOT_OK(PKSamplerObj::from_json(json_obj, num_samples, sampler));
  224. } else if (sampler_name == "RandomSampler") {
  225. RETURN_IF_NOT_OK(RandomSamplerObj::from_json(json_obj, num_samples, sampler));
  226. } else if (sampler_name == "SequentialSampler") {
  227. RETURN_IF_NOT_OK(SequentialSamplerObj::from_json(json_obj, num_samples, sampler));
  228. } else if (sampler_name == "SubsetSampler") {
  229. RETURN_IF_NOT_OK(SubsetSamplerObj::from_json(json_obj, num_samples, sampler));
  230. } else if (sampler_name == "SubsetRandomSampler") {
  231. RETURN_IF_NOT_OK(SubsetRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  232. } else if (sampler_name == "WeightedRandomSampler") {
  233. RETURN_IF_NOT_OK(WeightedRandomSamplerObj::from_json(json_obj, num_samples, sampler));
  234. } else {
  235. return Status(StatusCode::kMDUnexpectedError, "Invalid data, unsupported sampler type: " + sampler_name);
  236. }
  237. return Status::OK();
  238. }
  239. Status Serdes::ConstructTensorOps(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
  240. std::vector<std::shared_ptr<TensorOperation>> output;
  241. for (nlohmann::json item : json_obj) {
  242. if (item.find("python_module") != item.end()) {
  243. if (Py_IsInitialized() != 0) {
  244. RETURN_IF_NOT_OK(PyFuncOp::from_json(item, result));
  245. } else {
  246. LOG_AND_RETURN_STATUS_SYNTAX_ERROR(
  247. "Python module is not initialized or Pyfunction is not supported on this platform.");
  248. }
  249. } else {
  250. CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_name") != item.end(), "Failed to find tensor_op_name");
  251. CHECK_FAIL_RETURN_UNEXPECTED(item.find("tensor_op_params") != item.end(), "Failed to find tensor_op_params");
  252. std::string op_name = item["tensor_op_name"];
  253. nlohmann::json op_params = item["tensor_op_params"];
  254. std::shared_ptr<TensorOperation> operation = nullptr;
  255. CHECK_FAIL_RETURN_UNEXPECTED(func_ptr_.find(op_name) != func_ptr_.end(),
  256. "Invalid data, unsupported operation: " + op_name);
  257. RETURN_IF_NOT_OK(func_ptr_[op_name](op_params, &operation));
  258. output.push_back(operation);
  259. *result = output;
  260. }
  261. }
  262. return Status::OK();
  263. }
  264. std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> *operation)>
  265. Serdes::InitializeFuncPtr() {
  266. std::map<std::string, Status (*)(nlohmann::json json_obj, std::shared_ptr<TensorOperation> * operation)> ops_ptr;
  267. ops_ptr[vision::kAdjustGammaOperation] = &(vision::AdjustGammaOperation::from_json);
  268. ops_ptr[vision::kAffineOperation] = &(vision::AffineOperation::from_json);
  269. ops_ptr[vision::kAutoContrastOperation] = &(vision::AutoContrastOperation::from_json);
  270. ops_ptr[vision::kBoundingBoxAugmentOperation] = &(vision::BoundingBoxAugmentOperation::from_json);
  271. ops_ptr[vision::kCenterCropOperation] = &(vision::CenterCropOperation::from_json);
  272. ops_ptr[vision::kCropOperation] = &(vision::CropOperation::from_json);
  273. ops_ptr[vision::kCutMixBatchOperation] = &(vision::CutMixBatchOperation::from_json);
  274. ops_ptr[vision::kCutOutOperation] = &(vision::CutOutOperation::from_json);
  275. ops_ptr[vision::kDecodeOperation] = &(vision::DecodeOperation::from_json);
  276. #ifdef ENABLE_ACL
  277. ops_ptr[vision::kDvppCropJpegOperation] = &(vision::DvppCropJpegOperation::from_json);
  278. ops_ptr[vision::kDvppDecodeResizeOperation] = &(vision::DvppDecodeResizeOperation::from_json);
  279. ops_ptr[vision::kDvppDecodeResizeCropOperation] = &(vision::DvppDecodeResizeCropOperation::from_json);
  280. ops_ptr[vision::kDvppNormalizeOperation] = &(vision::DvppNormalizeOperation::from_json);
  281. ops_ptr[vision::kDvppResizeJpegOperation] = &(vision::DvppResizeJpegOperation::from_json);
  282. #endif
  283. ops_ptr[vision::kEqualizeOperation] = &(vision::EqualizeOperation::from_json);
  284. ops_ptr[vision::kGaussianBlurOperation] = &(vision::GaussianBlurOperation::from_json);
  285. ops_ptr[vision::kHorizontalFlipOperation] = &(vision::HorizontalFlipOperation::from_json);
  286. ops_ptr[vision::kHwcToChwOperation] = &(vision::HwcToChwOperation::from_json);
  287. ops_ptr[vision::kInvertOperation] = &(vision::InvertOperation::from_json);
  288. ops_ptr[vision::kMixUpBatchOperation] = &(vision::MixUpBatchOperation::from_json);
  289. ops_ptr[vision::kNormalizeOperation] = &(vision::NormalizeOperation::from_json);
  290. ops_ptr[vision::kNormalizePadOperation] = &(vision::NormalizePadOperation::from_json);
  291. ops_ptr[vision::kPadOperation] = &(vision::PadOperation::from_json);
  292. ops_ptr[vision::kRandomAffineOperation] = &(vision::RandomAffineOperation::from_json);
  293. ops_ptr[vision::kRandomColorOperation] = &(vision::RandomColorOperation::from_json);
  294. ops_ptr[vision::kRandomColorAdjustOperation] = &(vision::RandomColorAdjustOperation::from_json);
  295. ops_ptr[vision::kRandomCropDecodeResizeOperation] = &(vision::RandomCropDecodeResizeOperation::from_json);
  296. ops_ptr[vision::kRandomCropOperation] = &(vision::RandomCropOperation::from_json);
  297. ops_ptr[vision::kRandomCropWithBBoxOperation] = &(vision::RandomCropWithBBoxOperation::from_json);
  298. ops_ptr[vision::kRandomHorizontalFlipOperation] = &(vision::RandomHorizontalFlipOperation::from_json);
  299. ops_ptr[vision::kRandomHorizontalFlipWithBBoxOperation] = &(vision::RandomHorizontalFlipWithBBoxOperation::from_json);
  300. ops_ptr[vision::kRandomPosterizeOperation] = &(vision::RandomPosterizeOperation::from_json);
  301. ops_ptr[vision::kRandomResizeOperation] = &(vision::RandomResizeOperation::from_json);
  302. ops_ptr[vision::kRandomResizeWithBBoxOperation] = &(vision::RandomResizeWithBBoxOperation::from_json);
  303. ops_ptr[vision::kRandomResizedCropOperation] = &(vision::RandomResizedCropOperation::from_json);
  304. ops_ptr[vision::kRandomResizedCropWithBBoxOperation] = &(vision::RandomResizedCropWithBBoxOperation::from_json);
  305. ops_ptr[vision::kRandomRotationOperation] = &(vision::RandomRotationOperation::from_json);
  306. ops_ptr[vision::kRandomSelectSubpolicyOperation] = &(vision::RandomSelectSubpolicyOperation::from_json);
  307. ops_ptr[vision::kRandomSharpnessOperation] = &(vision::RandomSharpnessOperation::from_json);
  308. ops_ptr[vision::kRandomSolarizeOperation] = &(vision::RandomSolarizeOperation::from_json);
  309. ops_ptr[vision::kRandomVerticalFlipOperation] = &(vision::RandomVerticalFlipOperation::from_json);
  310. ops_ptr[vision::kRandomVerticalFlipWithBBoxOperation] = &(vision::RandomVerticalFlipWithBBoxOperation::from_json);
  311. ops_ptr[vision::kRandomSharpnessOperation] = &(vision::RandomSharpnessOperation::from_json);
  312. ops_ptr[vision::kRandomSolarizeOperation] = &(vision::RandomSolarizeOperation::from_json);
  313. ops_ptr[vision::kRescaleOperation] = &(vision::RescaleOperation::from_json);
  314. ops_ptr[vision::kResizeOperation] = &(vision::ResizeOperation::from_json);
  315. ops_ptr[vision::kResizePreserveAROperation] = &(vision::ResizePreserveAROperation::from_json);
  316. ops_ptr[vision::kResizeWithBBoxOperation] = &(vision::ResizeWithBBoxOperation::from_json);
  317. ops_ptr[vision::kRgbaToBgrOperation] = &(vision::RgbaToBgrOperation::from_json);
  318. ops_ptr[vision::kRgbaToRgbOperation] = &(vision::RgbaToRgbOperation::from_json);
  319. ops_ptr[vision::kRgbToBgrOperation] = &(vision::RgbToBgrOperation::from_json);
  320. ops_ptr[vision::kRgbToGrayOperation] = &(vision::RgbToGrayOperation::from_json);
  321. ops_ptr[vision::kRotateOperation] = &(vision::RotateOperation::from_json);
  322. ops_ptr[vision::kSlicePatchesOperation] = &(vision::SlicePatchesOperation::from_json);
  323. ops_ptr[vision::kSoftDvppDecodeRandomCropResizeJpegOperation] =
  324. &(vision::SoftDvppDecodeRandomCropResizeJpegOperation::from_json);
  325. ops_ptr[vision::kSoftDvppDecodeResizeJpegOperation] = &(vision::SoftDvppDecodeResizeJpegOperation::from_json);
  326. ops_ptr[vision::kSwapRedBlueOperation] = &(vision::SwapRedBlueOperation::from_json);
  327. ops_ptr[vision::kUniformAugOperation] = &(vision::UniformAugOperation::from_json);
  328. ops_ptr[vision::kVerticalFlipOperation] = &(vision::VerticalFlipOperation::from_json);
  329. ops_ptr[transforms::kFillOperation] = &(transforms::FillOperation::from_json);
  330. ops_ptr[transforms::kOneHotOperation] = &(transforms::OneHotOperation::from_json);
  331. ops_ptr[transforms::kTypeCastOperation] = &(transforms::TypeCastOperation::from_json);
  332. ops_ptr[text::kToNumberOperation] = &(text::ToNumberOperation::from_json);
  333. return ops_ptr;
  334. }
  335. Status Serdes::ParseMindIRPreprocess(const std::vector<std::string> &map_json_string,
  336. std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph) {
  337. CHECK_FAIL_RETURN_UNEXPECTED(!map_json_string.empty(), "Invalid data, no json data in map_json_string.");
  338. const std::string process_column = "[\"image\"]";
  339. MS_LOG(WARNING) << "Only supports parse \"image\" column from dataset object.";
  340. nlohmann::json map_json;
  341. try {
  342. for (auto &json : map_json_string) {
  343. map_json = nlohmann::json::parse(json);
  344. if (map_json["input_columns"].dump() == process_column) {
  345. break;
  346. }
  347. }
  348. } catch (const std::exception &err) {
  349. MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data, error message: " << err.what();
  350. RETURN_STATUS_UNEXPECTED("Invalid json content, failed to parse JSON data.");
  351. }
  352. if (map_json.empty()) {
  353. MS_LOG(ERROR) << "Invalid json content, no JSON data found for given input column: " + process_column;
  354. RETURN_STATUS_UNEXPECTED("Invalid json content, no JSON data found for given input column: " + process_column);
  355. }
  356. while (map_json != nullptr) {
  357. CHECK_FAIL_RETURN_UNEXPECTED(map_json["op_type"] == "Map", "Invalid json content, this is not a MapOp.");
  358. std::vector<std::shared_ptr<TensorOperation>> tensor_ops;
  359. RETURN_IF_NOT_OK(ConstructTensorOps(map_json["operations"], &tensor_ops));
  360. if (map_json["input_columns"].dump() == process_column) {
  361. std::vector<std::string> op_names;
  362. std::transform(tensor_ops.begin(), tensor_ops.end(), std::back_inserter(op_names),
  363. [](const auto &op) { return op->Name(); });
  364. MS_LOG(INFO) << "Find valid preprocess operations: " << op_names;
  365. data_graph->push_back(std::make_shared<Execute>(tensor_ops));
  366. }
  367. map_json = map_json["children"];
  368. }
  369. if (!data_graph->size()) {
  370. MS_LOG(WARNING) << "Can not find any valid preprocess operation.";
  371. }
  372. return Status::OK();
  373. }
  374. Status Serdes::UpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json,
  375. const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map) {
  376. RETURN_UNEXPECTED_IF_NULL(serialized_json);
  377. int32_t op_id = 0;
  378. return RecurseUpdateOptimizedIRTreeJSON(serialized_json, &op_id, op_map);
  379. }
  380. bool IsDatasetOpMatchIRNode(std::string_view ir_node_name, std::string_view dataset_op_name) {
  381. // Helper function to match IR Node name to its dataset op name
  382. if (ir_node_name == kSyncWaitNode) {
  383. return dataset_op_name == kBarrierOp;
  384. } else if (ir_node_name == kCifar10Node || ir_node_name == kCifar100Node) {
  385. return dataset_op_name == "CifarOp";
  386. } else if (ir_node_name == kMindDataNode) {
  387. return dataset_op_name == "MindRecordOp";
  388. } else if (ir_node_name == kRandomNode) {
  389. return dataset_op_name == "RandomDataOp";
  390. } else if (ir_node_name == kTFRecordNode) {
  391. return dataset_op_name == "TFReaderOp";
  392. } else if (ir_node_name == kIWSLT2016Node || ir_node_name == kIWSLT2017Node) {
  393. return dataset_op_name == "IWSLTOp";
  394. } else {
  395. // Generic way of matching, special cases handled above. Special cases will evolve over time.
  396. return ir_node_name.substr(0, ir_node_name.find("Dataset")) ==
  397. dataset_op_name.substr(0, dataset_op_name.find("Op"));
  398. }
  399. }
  400. Status Serdes::RecurseUpdateOptimizedIRTreeJSON(nlohmann::json *serialized_json, int32_t *op_id,
  401. const std::map<int32_t, std::shared_ptr<DatasetOp>> &op_map) {
  402. RETURN_UNEXPECTED_IF_NULL(serialized_json);
  403. RETURN_UNEXPECTED_IF_NULL(op_id);
  404. std::string ir_node_name = (*serialized_json)["op_type"];
  405. MS_LOG(INFO) << "Visiting IR Node: " << ir_node_name;
  406. // Each IR Node should have a corresponding dataset node in the execution tree but the reverse is not necessarily true
  407. while (!IsDatasetOpMatchIRNode(ir_node_name, op_map.find(*op_id)->second->Name())) {
  408. // During the construction of execution tree, extra dataset nodes may have been inserted
  409. // Skip dataset ops unless we get to the expected node
  410. MS_LOG(INFO) << "\tSkipping dataset op: " << op_map.find(*op_id)->second->NameWithID();
  411. ++(*op_id);
  412. CHECK_FAIL_RETURN_UNEXPECTED(*op_id < op_map.size(), "op_id is out of bounds");
  413. }
  414. MS_LOG(INFO) << "\tMatch found for IR Node: " << ir_node_name
  415. << " with dataset op: " << op_map.find(*op_id)->second->NameWithID();
  416. if (!op_map.find(*op_id)->second->inlined() && serialized_json->contains("num_parallel_workers") &&
  417. serialized_json->contains("connector_queue_size")) {
  418. (*serialized_json)["num_parallel_workers"] = op_map.find(*op_id)->second->NumWorkers();
  419. (*serialized_json)["connector_queue_size"] = op_map.find(*op_id)->second->ConnectorCapacity();
  420. }
  421. ++(*op_id);
  422. auto num_children = (*serialized_json)["children"].size();
  423. for (int i = 0; i < num_children; ++i) {
  424. RETURN_IF_NOT_OK(RecurseUpdateOptimizedIRTreeJSON(&(*serialized_json)["children"][i], op_id, op_map));
  425. }
  426. return Status::OK();
  427. }
  428. // In the current stage, there is a cyclic dependency between libmindspore.so and c_dataengine.so,
  429. // we make a C function here and dlopen by libminspore.so to avoid linking explicitly,
  430. // will be fix after decouling libminspore.so into multi submodules
  431. extern "C" {
  432. // ParseMindIRPreprocess_C has C-linkage specified, but returns user-defined type 'mindspore::Status'
  433. // which is incompatible with C
  434. void ParseMindIRPreprocess_C(const std::vector<std::string> &dataset_json,
  435. std::vector<std::shared_ptr<mindspore::dataset::Execute>> *data_graph, Status *s) {
  436. Status ret = Serdes::ParseMindIRPreprocess(dataset_json, data_graph);
  437. *s = Status(ret);
  438. }
  439. }
  440. } // namespace dataset
  441. } // namespace mindspore