|
|
|
@@ -83,88 +83,87 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) { |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() { |
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForAFQMCOrCMNLITask() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (task_ == "AFQMC") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["sentence1"] = "sentence1"; |
|
|
|
key_map["sentence2"] = "sentence2"; |
|
|
|
key_map["label"] = "label"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["sentence1"] = "sentence1"; |
|
|
|
key_map["sentence2"] = "sentence2"; |
|
|
|
} |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
} |
|
|
|
if (task_ == "CMNLI") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["sentence1"] = "sentence1"; |
|
|
|
key_map["sentence2"] = "sentence2"; |
|
|
|
key_map["label"] = "label"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["sentence1"] = "sentence1"; |
|
|
|
key_map["sentence2"] = "sentence2"; |
|
|
|
} |
|
|
|
key_map["sentence1"] = "sentence1"; |
|
|
|
key_map["sentence2"] = "sentence2"; |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForCSLTask() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
} |
|
|
|
if (task_ == "CSL") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["abst"] = "abst"; |
|
|
|
key_map["keyword"] = "keyword"; |
|
|
|
key_map["label"] = "label"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["abst"] = "abst"; |
|
|
|
key_map["keyword"] = "keyword"; |
|
|
|
} |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["abst"] = "abst"; |
|
|
|
key_map["keyword"] = "keyword"; |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForIFLYTEKTask() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
key_map["label_des"] = "label_des"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
} |
|
|
|
if (task_ == "IFLYTEK") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
key_map["label_des"] = "label_des"; |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
} |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForTNEWSTask() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
key_map["label_desc"] = "label_desc"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
} |
|
|
|
if (task_ == "TNEWS") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
key_map["label_desc"] = "label_desc"; |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
key_map["keywords"] = "keywords"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["id"] = "id"; |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
key_map["keywords"] = "keywords"; |
|
|
|
} |
|
|
|
key_map["sentence"] = "sentence"; |
|
|
|
key_map["keywords"] = "keywords"; |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMapForWSCTask() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["label"] = "label"; |
|
|
|
} |
|
|
|
if (task_ == "WSC") { |
|
|
|
if (usage_ == "train" || usage_ == "eval") { |
|
|
|
key_map["span1_index"] = "target/span1_index"; |
|
|
|
key_map["span2_index"] = "target/span2_index"; |
|
|
|
key_map["span1_text"] = "target/span1_text"; |
|
|
|
key_map["span2_text"] = "target/span2_text"; |
|
|
|
key_map["idx"] = "idx"; |
|
|
|
key_map["label"] = "label"; |
|
|
|
key_map["text"] = "text"; |
|
|
|
} else { // usage_ == "test" |
|
|
|
key_map["span1_index"] = "target/span1_index"; |
|
|
|
key_map["span2_index"] = "target/span2_index"; |
|
|
|
key_map["span1_text"] = "target/span1_text"; |
|
|
|
key_map["span2_text"] = "target/span2_text"; |
|
|
|
key_map["idx"] = "idx"; |
|
|
|
key_map["text"] = "text"; |
|
|
|
} |
|
|
|
key_map["span1_index"] = "target/span1_index"; |
|
|
|
key_map["span2_index"] = "target/span2_index"; |
|
|
|
key_map["span1_text"] = "target/span1_text"; |
|
|
|
key_map["span2_text"] = "target/span2_text"; |
|
|
|
key_map["idx"] = "idx"; |
|
|
|
key_map["text"] = "text"; |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::string> CLUENode::CreateKeyMap() { |
|
|
|
std::map<std::string, std::string> key_map; |
|
|
|
if (task_ == "AFQMC" || task_ == "CMNLI") { |
|
|
|
key_map = CreateKeyMapForAFQMCOrCMNLITask(); |
|
|
|
} else if (task_ == "CSL") { |
|
|
|
key_map = CreateKeyMapForCSLTask(); |
|
|
|
} else if (task_ == "IFLYTEK") { |
|
|
|
key_map = CreateKeyMapForIFLYTEKTask(); |
|
|
|
} else if (task_ == "TNEWS") { |
|
|
|
key_map = CreateKeyMapForTNEWSTask(); |
|
|
|
} else if (task_ == "WSC") { |
|
|
|
key_map = CreateKeyMapForWSCTask(); |
|
|
|
} |
|
|
|
return key_map; |
|
|
|
} |
|
|
|
|
|
|
|
// Function to build CLUENode |
|
|
|
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { |
|
|
|
auto key_map = CreateKeyMapForBuild(); |
|
|
|
auto key_map = CreateKeyMap(); |
|
|
|
ColKeyMap ck_map; |
|
|
|
for (auto &p : key_map) { |
|
|
|
ck_map.insert({p.first, split(p.second, '/')}); |
|
|
|
@@ -246,11 +245,11 @@ Status CLUENode::to_json(nlohmann::json *out_json) { |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is
// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing
// sampling support from the caching layer. That is why we set up the sampler for a leaf node that does not use
// sampling.
|
|
|
Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { |
|
|
|
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); |
|
|
|
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); |
|
|
|
|