From 5bbf4fe2b1cf7455d97064d81ffa15e1083c8c87 Mon Sep 17 00:00:00 2001 From: Lixia Chen Date: Fri, 9 Apr 2021 10:47:14 -0400 Subject: [PATCH] Make CLUENode::CreateKepMap() less than 50 lines --- .../engine/ir/datasetops/source/clue_node.cc | 149 +++++++++--------- .../engine/ir/datasetops/source/clue_node.h | 28 +++- 2 files changed, 98 insertions(+), 79 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index e60e6f94b3..4e58ade020 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -83,88 +83,87 @@ std::vector CLUENode::split(const std::string &s, char delim) { return res; } -std::map CLUENode::CreateKeyMapForBuild() { +std::map CLUENode::CreateKeyMapForAFQMCOrCMNLITask() { std::map key_map; - if (task_ == "AFQMC") { - if (usage_ == "train" || usage_ == "eval") { - key_map["sentence1"] = "sentence1"; - key_map["sentence2"] = "sentence2"; - key_map["label"] = "label"; - } else { // usage_ == "test" - key_map["id"] = "id"; - key_map["sentence1"] = "sentence1"; - key_map["sentence2"] = "sentence2"; - } + if (usage_ == "train" || usage_ == "eval") { + key_map["label"] = "label"; + } else { // usage_ == "test" + key_map["id"] = "id"; } - if (task_ == "CMNLI") { - if (usage_ == "train" || usage_ == "eval") { - key_map["sentence1"] = "sentence1"; - key_map["sentence2"] = "sentence2"; - key_map["label"] = "label"; - } else { // usage_ == "test" - key_map["id"] = "id"; - key_map["sentence1"] = "sentence1"; - key_map["sentence2"] = "sentence2"; - } + key_map["sentence1"] = "sentence1"; + key_map["sentence2"] = "sentence2"; + return key_map; +} + +std::map CLUENode::CreateKeyMapForCSLTask() { + std::map key_map; + if (usage_ == "train" || usage_ == "eval") { + key_map["label"] = "label"; } - if (task_ == "CSL") { - if (usage_ == "train" || usage_ == "eval") { - key_map["id"] = "id"; - key_map["abst"] = "abst"; - key_map["keyword"] = "keyword"; - key_map["label"] = "label"; - } else { // usage_ == "test" - key_map["id"] = "id"; - key_map["abst"] = "abst"; - key_map["keyword"] = "keyword"; - } + key_map["id"] = "id"; + key_map["abst"] = "abst"; + key_map["keyword"] = "keyword"; + return key_map; +} + +std::map CLUENode::CreateKeyMapForIFLYTEKTask() { + std::map key_map; + if (usage_ == "train" || usage_ == "eval") { + key_map["label"] = "label"; + key_map["label_des"] = "label_des"; + } else { // usage_ == "test" + key_map["id"] = "id"; } - if (task_ == "IFLYTEK") { - if (usage_ == "train" || usage_ == "eval") { - key_map["label"] = "label"; - key_map["label_des"] = "label_des"; - key_map["sentence"] = "sentence"; - } else { // usage_ == "test" - key_map["id"] = "id"; - key_map["sentence"] = "sentence"; - } + key_map["sentence"] = "sentence"; + return key_map; +} + +std::map CLUENode::CreateKeyMapForTNEWSTask() { + std::map key_map; + if (usage_ == "train" || usage_ == "eval") { + key_map["label"] = "label"; + key_map["label_desc"] = "label_desc"; + } else { // usage_ == "test" + key_map["id"] = "id"; } - if (task_ == "TNEWS") { - if (usage_ == "train" || usage_ == "eval") { - key_map["label"] = "label"; - key_map["label_desc"] = "label_desc"; - key_map["sentence"] = "sentence"; - key_map["keywords"] = "keywords"; - } else { // usage_ == "test" - key_map["id"] = "id"; - key_map["sentence"] = "sentence"; - key_map["keywords"] = "keywords"; - } + key_map["sentence"] = "sentence"; + key_map["keywords"] = "keywords"; + return key_map; +} + +std::map CLUENode::CreateKeyMapForWSCTask() { + std::map key_map; + if (usage_ == "train" || usage_ == "eval") { + key_map["label"] = "label"; } - if (task_ == "WSC") { - if (usage_ == "train" || usage_ == "eval") { - key_map["span1_index"] = "target/span1_index"; - key_map["span2_index"] = "target/span2_index"; - key_map["span1_text"] = "target/span1_text"; - key_map["span2_text"] = "target/span2_text"; - key_map["idx"] = "idx"; - key_map["label"] = "label"; - key_map["text"] = "text"; - } else { // usage_ == "test" - key_map["span1_index"] = "target/span1_index"; - key_map["span2_index"] = "target/span2_index"; - key_map["span1_text"] = "target/span1_text"; - key_map["span2_text"] = "target/span2_text"; - key_map["idx"] = "idx"; - key_map["text"] = "text"; - } + key_map["span1_index"] = "target/span1_index"; + key_map["span2_index"] = "target/span2_index"; + key_map["span1_text"] = "target/span1_text"; + key_map["span2_text"] = "target/span2_text"; + key_map["idx"] = "idx"; + key_map["text"] = "text"; + return key_map; +} + +std::map CLUENode::CreateKeyMap() { + std::map key_map; + if (task_ == "AFQMC" || task_ == "CMNLI") { + key_map = CreateKeyMapForAFQMCOrCMNLITask(); + } else if (task_ == "CSL") { + key_map = CreateKeyMapForCSLTask(); + } else if (task_ == "IFLYTEK") { + key_map = CreateKeyMapForIFLYTEKTask(); + } else if (task_ == "TNEWS") { + key_map = CreateKeyMapForTNEWSTask(); + } else if (task_ == "WSC") { + key_map = CreateKeyMapForWSCTask(); } return key_map; } // Function to build CLUENode Status CLUENode::Build(std::vector> *const node_ops) { - auto key_map = CreateKeyMapForBuild(); + auto key_map = CreateKeyMap(); ColKeyMap ck_map; for (auto &p : key_map) { ck_map.insert({p.first, split(p.second, '/')}); @@ -246,11 +245,11 @@ Status CLUENode::to_json(nlohmann::json *out_json) { return Status::OK(); } -// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. -// CLUE by itself is a non-mappable dataset that does not support sampling. -// However, if a cache operator is injected at some other place higher in the tree, that cache can -// inherit this sampler from the leaf, providing sampling support from the caching layer. -// That is why we setup the sampler for a leaf node that does not use sampling. +// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent +// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is +// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing +// sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not use +// sampling. Status CLUENode::SetupSamplerForCache(std::shared_ptr *sampler) { bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index 033e251813..b255462b44 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -50,10 +50,6 @@ class CLUENode : public NonMappableSourceNode { /// \return A shared pointer to the new copy std::shared_ptr Copy() override; - /// \brief Generate a key map to be used in Build() according to usage and task - /// \return The generated key map - std::map CreateKeyMapForBuild(); - /// \brief a base class override function to create the required runtime dataset op objects for this class /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create /// \return Status Status::OK() if build successfully @@ -111,6 +107,30 @@ class CLUENode : public NonMappableSourceNode { /// \return A string vector std::vector split(const std::string &s, char delim); + /// \brief Generate a key map for AFQMC or CMNLI task according to usage + /// \return The generated key map + std::map CreateKeyMapForAFQMCOrCMNLITask(); + + /// \brief Generate a key map for CSL task according to usage + /// \return The generated key map + std::map CreateKeyMapForCSLTask(); + + /// \brief Generate a key map for IFLYTEK task according to usage + /// \return The generated key map + std::map CreateKeyMapForIFLYTEKTask(); + + /// \brief Generate a key map for TNEWS task according to usage + /// \return The generated key map + std::map CreateKeyMapForTNEWSTask(); + + /// \brief Generate a key map for WSC task according to usage + /// \return The generated key map + std::map CreateKeyMapForWSCTask(); + + /// \brief Generate a key map to be used in Build() according to usage and task + /// \return The generated key map + std::map CreateKeyMap(); + std::vector dataset_files_; std::string task_; std::string usage_;