From: @zyli2020 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -53,7 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(prim::kPrimReduceAny->name(), {1}); | |||
| Register(prim::kPrimUnsortedSegmentMin->name(), {2}); | |||
| Register(prim::kPrimUnsortedSegmentMax->name(), {2}); | |||
| Register(kSparseGatherV2, {2}); | |||
| Register(kSparseGatherV2OpName, {2}); | |||
| Register(kUnsortedSegmentProdOpName, {2}); | |||
| Register(kSimpleMeanGradOpName, {1}); | |||
| Register(kMeanGradOpName, {1}); | |||
| @@ -109,7 +109,7 @@ bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_nam | |||
| ConstInputToAttrInfoRegister *reg) const { | |||
| if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { | |||
| *reg = op_input_to_attr_map_.at(op_name); | |||
| MS_LOG(DEBUG) << op_name << " const2attr find in registery."; | |||
| MS_LOG(DEBUG) << op_name << " const2attr find in registry."; | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -31,15 +31,22 @@ std::string GetOpPythonPath(const OperatorName &op_name) { | |||
| // almost all ops are defined in two main paths | |||
| const std::string ops_module = OP_PATH; | |||
| const std::string inner_ops_module = INNER_OP_PATH; | |||
| const std::string functional_op_module = FUNCTIONAL_OP_PATH; | |||
| py::module mod = py::module::import(common::SafeCStr(ops_module)); | |||
| py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); | |||
| if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { | |||
| if (!py::hasattr(mod, common::SafeCStr(op_name))) { | |||
| MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; | |||
| } | |||
| py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module)); | |||
| if (py::hasattr(inner_mod, common::SafeCStr(op_name))) { | |||
| return inner_ops_module; | |||
| } | |||
| if (py::hasattr(mod, common::SafeCStr(op_name))) { | |||
| return ops_module; | |||
| } | |||
| return inner_ops_module; | |||
| if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) { | |||
| MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module | |||
| << " don't have op:" << op_name; | |||
| } | |||
| return functional_op_module; | |||
| } | |||
| ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { | |||
| @@ -141,7 +148,7 @@ Status GenerateGraph::Init(const CNodePtr &cnode) { | |||
| } | |||
| AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) { | |||
| CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode | |||
| CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to create anfnode | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| cnode->set_scope(scope_); | |||
| if (inputs.size() < 2) { | |||
| @@ -24,8 +24,10 @@ | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #include "frontend/parallel/context.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -158,6 +160,15 @@ Status GatherV2PInfo::GetAttrs() { | |||
| if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { | |||
| dynamic_shape_indices_ = true; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | |||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||
| bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE); | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) { | |||
| dynamic_shape_indices_ = true; | |||
| } | |||
| #endif | |||
| return SUCCESS; | |||
| } | |||
| @@ -531,7 +542,7 @@ Status GatherV2PInfo::InferBias() { | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| bias_ = 0; | |||
| bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); | |||
| return SUCCESS; | |||
| } | |||
| #endif | |||
| @@ -68,6 +68,7 @@ constexpr char REDUCE_OP_MAX[] = "max"; | |||
| constexpr char REDUCE_OP_MIN[] = "min"; | |||
| constexpr char OP_PATH[] = "mindspore.ops.operations"; | |||
| constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; | |||
| constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional"; | |||
| constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; | |||
| constexpr char GET_OP_FUNCTION[] = "_get_python_op"; | |||
| constexpr char KEEP_DIMS[] = "keep_dims"; | |||
| @@ -23,9 +23,13 @@ | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -186,5 +190,63 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| GenerateGraph gen_g = GenerateGraph(); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||
| return FAILED; | |||
| } | |||
| auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); | |||
| auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size()); | |||
| auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)}); | |||
| auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); | |||
| auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)}); | |||
| auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); | |||
| auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()}); | |||
| auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)}); | |||
| auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)}); | |||
| auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1}); | |||
| auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); | |||
| auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast}); | |||
| Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | |||
| OperatorAttrs attrs = {attr_op}; | |||
| AnfNodePtr reduce_op; | |||
| reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul}); | |||
| auto make_tuple = gen_g.PushBack({gen_g.NewOpInst(MAKE_TUPLE), tuple_getitem_0, reduce_op}); | |||
| std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 1), std::make_pair(unique, 1)}; | |||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( | |||
| std::make_pair(input_nodes, make_tuple)); | |||
| return SUCCESS; | |||
| } | |||
| #endif | |||
| ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Invalid inputs"; | |||
| } | |||
| const auto &primitive = GetValueNode<PrimitivePtr>(inputs[0]); | |||
| const auto &attr = primitive->GetAttr("cache_enable"); | |||
| if (attr == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto need_mask = GetValue<bool>(attr); | |||
| if (!need_mask) { | |||
| return nullptr; | |||
| } | |||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; | |||
| } | |||
| return replace_graph_; | |||
| } | |||
| #endif | |||
| return nullptr; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -39,6 +39,7 @@ class UniqueInfo : public OperatorInfo { | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| @@ -50,8 +51,12 @@ class UniqueInfo : public OperatorInfo { | |||
| Status InferMirrorOps() override; | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| Status InferAsLossDivisor() override { return SUCCESS; } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||
| #endif | |||
| private: | |||
| std::string replace_op_name_ = UNIQUE; | |||
| int64_t dev_num_ = 1; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -321,7 +321,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.") | |||
| .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") | |||
| .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") | |||
| .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not."); | |||
| .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.") | |||
| .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode."); | |||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | |||
| .def(py::init()) | |||
| @@ -773,12 +773,14 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() { | |||
| for (auto cnode : cnodes) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) { | |||
| if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) { | |||
| auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); | |||
| MS_EXCEPTION_IF_NULL(embedding_table); | |||
| MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; | |||
| embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>())); | |||
| count++; | |||
| if (embedding_table->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; | |||
| embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>())); | |||
| count++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -35,11 +35,11 @@ void PsCacheManager::InsertHashTableSize(const std::string ¶m_name, size_t c | |||
| if (vocab_size_ == 0) { | |||
| vocab_size_ = vocab_size; | |||
| } | |||
| if (cache_vocab_size_ == 0) { | |||
| cache_vocab_size_ = cache_vocab_size; | |||
| if (vocab_cache_size_ == 0) { | |||
| vocab_cache_size_ = cache_vocab_size; | |||
| } | |||
| if (host_cache_vocab_size_ == 0) { | |||
| host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor; | |||
| if (host_vocab_cache_size_ == 0) { | |||
| host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor; | |||
| } | |||
| } | |||
| @@ -148,8 +148,8 @@ void PsCacheManager::Initialize() { | |||
| Util::SetInternalEnvVar(); | |||
| worker.Run(); | |||
| } | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_); | |||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_); | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_); | |||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_); | |||
| AddEmbeddingTable(); | |||
| AllocMemForHashTable(); | |||
| SetLocalIdRank(); | |||
| @@ -220,13 +220,13 @@ void PsCacheManager::AllocMemForHashTable() { | |||
| for (auto &item : hash_tables_) { | |||
| size_t embedding_size = item.second.embedding_size; | |||
| auto &device_address = item.second.device_address; | |||
| device_address.size = cache_vocab_size_ * embedding_size * sizeof(float); | |||
| device_address.size = vocab_cache_size_ * embedding_size * sizeof(float); | |||
| auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size); | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| device_address.addr = addr; | |||
| auto &host_address = item.second.host_address; | |||
| auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size]; | |||
| auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size]; | |||
| MS_EXCEPTION_IF_NULL(host_address_ptr); | |||
| host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>()); | |||
| MS_EXCEPTION_IF_NULL(host_address); | |||
| @@ -239,21 +239,28 @@ void PsCacheManager::AllocMemForHashTable() { | |||
| embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>( | |||
| embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); | |||
| if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) { | |||
| if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) { | |||
| MS_LOG(EXCEPTION) << "MallocConstantMemory failed."; | |||
| } | |||
| } | |||
| void PsCacheManager::SetLocalIdRank() { | |||
| auto worker_num = ::ps::NumWorkers(); | |||
| auto worker_id = ::ps::MyRank(); | |||
| auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num)); | |||
| range_bound_.first = local_shard_size * worker_id; | |||
| range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_); | |||
| MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first | |||
| << ", rank id end:" << range_bound_.second; | |||
| auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num)); | |||
| vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_); | |||
| emb_table_slice_bounds_.first = local_shard_size * rank_id_; | |||
| emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_)); | |||
| cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_; | |||
| cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_); | |||
| MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_ | |||
| << ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second | |||
| << ", cache indices begin: " << cache_indices_bounds_.first | |||
| << ", cache indices end: " << cache_indices_bounds_.second | |||
| << ", vocab_cache_size_diff: " << vocab_cache_size_diff_; | |||
| } | |||
| int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; } | |||
| std::string PsCacheManager::channel_name() { | |||
| std::lock_guard<std::mutex> locker(channel_mutex_); | |||
| return channel_name_; | |||
| @@ -398,8 +405,8 @@ bool PsCacheManager::ProcessData() { | |||
| return true; | |||
| } | |||
| bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, | |||
| bool *in_device, size_t *hash_hit_count) { | |||
| bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, | |||
| bool *in_device, bool *out_range, size_t *hash_hit_count) { | |||
| MS_ERROR_IF_NULL(batch_ids); | |||
| MS_ERROR_IF_NULL(hash_index); | |||
| MS_ERROR_IF_NULL(in_device); | |||
| @@ -410,9 +417,19 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc | |||
| const auto &hash_id_to_index = device_hash_map->hash_id_to_index(); | |||
| for (size_t i = 0; i < batch_ids_len; ++i) { | |||
| if (batch_ids[i] < emb_table_slice_bounds_.first) { | |||
| hash_index[i] = batch_ids[i] - vocab_cache_size_diff_; | |||
| out_range[i] = true; | |||
| continue; | |||
| } | |||
| if (batch_ids[i] >= emb_table_slice_bounds_.second) { | |||
| hash_index[i] = batch_ids[i] + cache_indices_bounds_.second; | |||
| out_range[i] = true; | |||
| continue; | |||
| } | |||
| auto iter = hash_id_to_index.find(batch_ids[i]); | |||
| if (iter != hash_id_to_index.end()) { | |||
| hash_index[i] = iter->second; | |||
| hash_index[i] = iter->second + cache_indices_bounds_.first; | |||
| if (device_hash_map->hash_step(iter->second) != data_step_) { | |||
| ++(*hash_hit_count); | |||
| device_hash_map->set_hash_step(iter->second, data_step_); | |||
| @@ -423,11 +440,12 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc | |||
| return true; | |||
| } | |||
| bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, | |||
| bool *in_device) { | |||
| bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, | |||
| bool *in_device, bool *out_range) { | |||
| MS_ERROR_IF_NULL(batch_ids); | |||
| MS_ERROR_IF_NULL(hash_index); | |||
| MS_ERROR_IF_NULL(in_device); | |||
| MS_ERROR_IF_NULL(out_range); | |||
| size_t thread_num = batch_ids_len / kMinIdsPerThread + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| @@ -441,8 +459,9 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id | |||
| break; | |||
| } | |||
| size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0); | |||
| threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens, | |||
| hash_index + task_offset, in_device + task_offset, hash_hit_count + i); | |||
| threads[i] = | |||
| std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens, | |||
| hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i); | |||
| task_offset += task_proc_lens; | |||
| } | |||
| if (task_offset != batch_ids_len) { | |||
| @@ -477,27 +496,26 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||
| MS_ERROR_IF_NULL(hash_index); | |||
| statistics_info_.batch_id_count_ = batch_ids_len; | |||
| std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]); | |||
| std::unique_ptr<bool[]> out_range(new bool[batch_ids_len]); | |||
| if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) { | |||
| MS_LOG(EXCEPTION) << "Data in device memset failed."; | |||
| MS_LOG(EXCEPTION) << "Initialize in_device array failed."; | |||
| } | |||
| if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) { | |||
| MS_LOG(EXCEPTION) << "Initialize out_range array failed."; | |||
| } | |||
| CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); | |||
| RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get())); | |||
| RETURN_IF_FALSE(ResetEmbeddingHashMap()); | |||
| for (size_t i = 0; i < batch_ids_len; i++) { | |||
| if (in_device[i]) { | |||
| if (in_device[i] || out_range[i]) { | |||
| continue; | |||
| } | |||
| bool need_swap_host_to_device = true; | |||
| bool need_swap_device_to_host = true; | |||
| auto id = batch_ids[i]; | |||
| if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) { | |||
| hash_index[i] = -1; | |||
| continue; | |||
| } | |||
| int index = INVALID_INDEX_VALUE; | |||
| RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index)); | |||
| hash_index[i] = index; | |||
| RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index)); | |||
| hash_index[i] = index + cache_indices_bounds_.first; | |||
| if (need_swap_host_to_device) { | |||
| RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); | |||
| RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i])); | |||
| } | |||
| if (need_swap_device_to_host) { | |||
| RETURN_IF_FALSE(ParseHostDataDeviceToHost()); | |||
| @@ -667,7 +685,7 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, | |||
| bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t first_dim_size = host_vocab_cache_size_; | |||
| size_t outer_dim_size = embedding_size; | |||
| size_t thread_num = indices_lens / 10000 + 1; | |||
| @@ -697,7 +715,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l | |||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||
| float *insert_data, float *hash_table_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t first_dim_size = host_vocab_cache_size_; | |||
| size_t thread_num = insert_indices_size / 10000 + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| std::thread threads[kMaxThreadNum]; | |||
| @@ -125,7 +125,10 @@ class PsCacheManager { | |||
| const size_t &QueryHashTableSize(const std::string ¶m_name) const; | |||
| bool IsHashTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } | |||
| void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } | |||
| void set_rank_id(int rank_id) { rank_id_ = rank_id; } | |||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | |||
| size_t vocab_cache_size() const { return vocab_cache_size_; } | |||
| int cache_indices_lower_bound() const; | |||
| void DoProcessData(uint32_t device_id, void *context); | |||
| void IncreaseGraphStep(const std::string &channel_name); | |||
| void SyncEmbeddingTable(); | |||
| @@ -170,10 +173,12 @@ class PsCacheManager { | |||
| void DumpStatisticsInfo(size_t each_print_step = 1000); | |||
| bool SyncHostEmbeddingTable(); | |||
| bool SyncDeviceEmbeddingTable(); | |||
| bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | |||
| size_t *hash_hit_count); | |||
| bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device); | |||
| bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | |||
| bool *out_range, size_t *hash_hit_count); | |||
| bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | |||
| bool *out_range); | |||
| bool ResetEmbeddingHashMap(); | |||
| bool initialized_ps_cache_{false}; | |||
| std::string channel_name_; | |||
| std::mutex channel_mutex_; | |||
| @@ -190,11 +195,14 @@ class PsCacheManager { | |||
| std::shared_ptr<EmbeddingHostCache> embedding_host_cache_; | |||
| size_t vocab_size_{0}; | |||
| size_t cache_vocab_size_{0}; | |||
| size_t host_cache_vocab_size_{0}; | |||
| size_t vocab_cache_size_{0}; | |||
| size_t host_vocab_cache_size_{0}; | |||
| size_t batch_elements_{0}; | |||
| PsCacheStatisticsInfo statistics_info_; | |||
| std::pair<size_t, size_t> range_bound_; | |||
| std::pair<int, int> emb_table_slice_bounds_; | |||
| std::pair<int, int> cache_indices_bounds_; | |||
| int vocab_cache_size_diff_{0}; | |||
| int rank_id_{0}; | |||
| std::atomic_bool finish_insert_init_info_{false}; | |||
| std::atomic_bool finish_init_parameter_server_{false}; | |||
| std::atomic_bool running_{false}; | |||
| @@ -129,5 +129,11 @@ void PSContext::set_cache_enable(bool cache_enable) const { | |||
| PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | |||
| #endif | |||
| } | |||
| void PSContext::set_rank_id(int rank_id) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.set_rank_id(rank_id); | |||
| #endif | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -52,6 +52,7 @@ class PSContext { | |||
| void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; | |||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; | |||
| void set_cache_enable(bool cache_enable) const; | |||
| void set_rank_id(int rank_id) const; | |||
| private: | |||
| PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} | |||
| @@ -391,7 +391,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) { | |||
| bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| InnerSetContext(); | |||
| if (graph->is_dynamic_shape()) { | |||
| if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { | |||
| if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) { | |||
| MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode."; | |||
| } | |||
| if (DumpJsonParser::GetInstance().async_dump_enabled()) { | |||
| @@ -851,7 +851,7 @@ void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_addre | |||
| MS_LOG(WARNING) << "Unexpected device address status: " << status; | |||
| break; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Invaild device address status: " << status; | |||
| MS_LOG(EXCEPTION) << "Invalid device address status: " << status; | |||
| } | |||
| } | |||
| @@ -1092,6 +1092,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) | |||
| MS_EXCEPTION_IF_NULL(mem_reuse_util_); | |||
| auto cnode = kernel->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Can not free the input addr of communication op when enable multi stream | |||
| if (AnfAlgo::IsCommunicationOp(kernel)) { | |||
| return; | |||
| } | |||
| @@ -1106,7 +1107,9 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) | |||
| } | |||
| auto kernel_with_index = GetPrevNodeOutput(kernel, i); | |||
| if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) { | |||
| // Maintain output addr of fused communication op to improve training performance | |||
| if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) && | |||
| AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) { | |||
| continue; | |||
| } | |||
| @@ -1049,7 +1049,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel); | |||
| if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) { | |||
| continue; | |||
| } | |||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); | |||
| @@ -1061,13 +1062,15 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | |||
| continue; | |||
| } | |||
| auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); | |||
| while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); | |||
| while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| } | |||
| if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { | |||
| auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first); | |||
| if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) { | |||
| bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); | |||
| if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { | |||
| if ((!full_batch && (input_index_node_name != kUniqueOpName)) || | |||
| (full_batch && (input_index_node_name != kMinimumOpName))) { | |||
| MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() | |||
| << ") cache is from " << input_index.first->fullname_with_scope(); | |||
| MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " | |||
| @@ -1082,6 +1085,28 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | |||
| } | |||
| } | |||
| void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true); | |||
| while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) { | |||
| pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true); | |||
| MS_EXCEPTION_IF_NULL(pre_node.first); | |||
| } | |||
| if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) { | |||
| MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode"; | |||
| } | |||
| pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true); | |||
| while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) { | |||
| pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true); | |||
| MS_EXCEPTION_IF_NULL(pre_node.first); | |||
| } | |||
| if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) { | |||
| MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices " | |||
| "value can not be changed before delivering to kernel[Unique] in parameter server cache mode."; | |||
| } | |||
| } | |||
| void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| AnfNodePtr first_cache_input_index = nullptr; | |||
| @@ -1090,16 +1115,23 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g | |||
| MS_EXCEPTION_IF_NULL(first_cache_input_index); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel); | |||
| if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) { | |||
| continue; | |||
| } | |||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); | |||
| auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true); | |||
| MS_EXCEPTION_IF_NULL(input_param.first); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| if (!input_param.first->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto param_name = input_param.first->fullname_with_scope(); | |||
| while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) { | |||
| CheckSparsePSEmbeddingCache(kernel); | |||
| } | |||
| while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| } | |||
| if (input_index.first == first_cache_input_index) { | |||
| @@ -138,6 +138,7 @@ class KernelRuntime { | |||
| void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, | |||
| size_t *first_cache_size); | |||
| void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | |||
| void CheckSparsePSEmbeddingCache(const CNodePtr &node); | |||
| #endif | |||
| protected: | |||
| @@ -83,7 +83,7 @@ constexpr auto kScatterNdOpName = "ScatterNd"; | |||
| constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign"; | |||
| constexpr auto kStridedSliceOpName = "StridedSlice"; | |||
| constexpr auto kStridedSliceGradOpName = "StridedSliceGrad"; | |||
| constexpr auto kSparseGatherV2 = "SparseGatherV2"; | |||
| constexpr auto kSparseGatherV2OpName = "SparseGatherV2"; | |||
| constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd"; | |||
| constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin"; | |||
| constexpr auto kFlattenGradOpName = "FlattenGrad"; | |||
| @@ -73,6 +73,13 @@ inline size_t FloatToSize(float u) { | |||
| } | |||
| inline float IntToFloat(int32_t v) { return static_cast<float>(v); } | |||
| inline int FloatToInt(float u) { | |||
| if (u > static_cast<float>((std::numeric_limits<int>::max)())) { | |||
| MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int."; | |||
| } | |||
| return static_cast<int>(u); | |||
| } | |||
| inline float SizeToFloat(size_t v) { return static_cast<float>(v); } | |||
| inline double LongToDouble(int64_t v) { return static_cast<double>(v); } | |||
| @@ -20,10 +20,12 @@ from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.context import ParallelMode, get_context | |||
| from mindspore.communication.management import get_group_size, get_rank | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch | |||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context | |||
| from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context | |||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id | |||
| from mindspore import context | |||
| from mindspore._checkparam import Rel | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.ops.primitive import constexpr | |||
| @@ -227,8 +229,6 @@ class EmbeddingLookup(Cell): | |||
| self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | |||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||
| name='embedding_table') | |||
| if self.cache_enable and enable_ps: | |||
| self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size) | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.gather_revert = P.GatherV2() | |||
| @@ -238,6 +238,10 @@ class EmbeddingLookup(Cell): | |||
| self.shape = P.Shape() | |||
| if is_auto_parallel: | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| if self.cache_enable and enable_ps: | |||
| self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size) | |||
| if is_auto_parallel: | |||
| self.unique.add_prim_attr('cache_enable', True) | |||
| indices_shape_size = 2 | |||
| if slice_mode == "field_slice" and is_auto_parallel: | |||
| if not manual_shapes: | |||
| @@ -252,7 +256,7 @@ class EmbeddingLookup(Cell): | |||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | |||
| full_batch = _get_full_batch() | |||
| if target == 'DEVICE' and not full_batch: | |||
| if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse): | |||
| indices_shape_size = 1 | |||
| self.gather_revert.shard(((1, 1), (get_group_size(),))) | |||
| self.forward_unique = True | |||
| @@ -293,7 +297,7 @@ class EmbeddingLookup(Cell): | |||
| raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.") | |||
| if not self.sparse: | |||
| raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.") | |||
| if get_context("device_target") != 'Ascend': | |||
| if context.get_context("device_target") != 'Ascend': | |||
| raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.") | |||
| logger.info("EmbeddingLookup cache enable takes effect.") | |||
| @@ -320,21 +324,29 @@ class EmbeddingLookup(Cell): | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| if is_auto_parallel: | |||
| device_num = get_group_size() | |||
| rank_size = get_group_size() | |||
| rank_id = get_rank() | |||
| full_batch = _get_full_batch() | |||
| if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"): | |||
| if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"): | |||
| raise ValueError("The embeddingLookup cache of parameter server parallel only be used " | |||
| "in 'full_batch' and 'table_row_slice' parallel strategy.") | |||
| self.vocab_cache_size = self.vocab_cache_size * device_num | |||
| self.vocab_cache_size = self.vocab_cache_size * rank_size | |||
| _set_rank_id(rank_id) | |||
| self.cache_enable = True | |||
| if _is_role_worker(): | |||
| self.vocab_size = self.vocab_cache_size | |||
| if context.get_context("enable_sparse") != self.sparse: | |||
| raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup " | |||
| "kernels and equal the value of 'enable_sparse' in context setting in " | |||
| "parameter server cache mode") | |||
| def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size): | |||
| """PS embeddingLookup cache enable set.""" | |||
| self.embedding_table.cache_enable = True | |||
| self.embedding_table.is_param_ps = True | |||
| _set_cache_enable(True) | |||
| if self.sparse: | |||
| self.forward_unique = True | |||
| if _is_role_worker(): | |||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||
| @@ -28,14 +28,15 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt") | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", | |||
| "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power, | |||
| beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter): | |||
| beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable): | |||
| """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse.""" | |||
| success = True | |||
| indices = gradient.indices | |||
| values = gradient.values | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| shapes = (op_shape(params), op_shape(m), op_shape(v), | |||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | |||
| @@ -75,12 +76,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, | |||
| @_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter): | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power, | |||
| beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable): | |||
| """Apply lazy adam optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), | |||
| (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) | |||
| @@ -245,12 +246,14 @@ class LazyAdam(Optimizer): | |||
| success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps), | |||
| lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) | |||
| lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, | |||
| self.cache_enable) | |||
| else: | |||
| success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr), | |||
| gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters) | |||
| gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters, | |||
| self.cache_enable) | |||
| return success | |||
| @Optimizer.target.setter | |||
| @@ -142,3 +142,6 @@ def _set_cache_enable(cache_enable): | |||
| os.environ['GOTO_NUM_THREADS'] = '2' | |||
| os.environ['OMP_NUM_THREADS'] = '2' | |||
| ps_context().set_cache_enable(cache_enable) | |||
| def _set_rank_id(rank_id): | |||
| ps_context().set_rank_id(rank_id) | |||
| @@ -190,7 +190,10 @@ def _get_python_op(op_name, op_path, instance_name, arglist): | |||
| """Get python operator.""" | |||
| module = __import__(op_path, fromlist=["None"]) | |||
| cls = getattr(module, op_name) | |||
| op = cls(*arglist) | |||
| if op_path != "mindspore.ops.functional": | |||
| op = cls(*arglist) | |||
| else: | |||
| op = cls | |||
| op.set_prim_instance_name(instance_name) | |||
| return op | |||
| @@ -17,7 +17,8 @@ | |||
| #bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET | |||
| # LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM | |||
| # SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE VOCAB_CACHE_SIZE | |||
| # SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE | |||
| # VOCAB_CACHE_SIZE SPARSE | |||
| execute_path=$(pwd) | |||
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| @@ -37,11 +38,16 @@ export MS_SCHED_PORT=$9 | |||
| export MS_ROLE=${10} | |||
| export RANK_TABLE_FILE=${11} | |||
| export VOCAB_CACHE_SIZE=${12} | |||
| export SPARSE=${13} | |||
| if [[ ! -n "${12}" ]]; then | |||
| export VOCAB_CACHE_SIZE=0 | |||
| fi | |||
| if [[ ! -n "${13}" ]]; then | |||
| export SPARSE=0 | |||
| fi | |||
| echo "=====Role is $MS_ROLE======" | |||
| if [[ "$MS_ROLE" == "MS_SCHED" ]]; then | |||
| @@ -73,7 +79,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then | |||
| mpirun --allow-run-as-root -n $LOCAL_WORKER_NUM --output-filename log_output --merge-stderr-to-stdout \ | |||
| python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ | |||
| --device_target=$DEVICE --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & | |||
| else | |||
| for((i=0;i<$LOCAL_WORKER_NUM;i++)); | |||
| do | |||
| @@ -84,7 +90,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ | |||
| --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 & | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 & | |||
| done | |||
| fi | |||
| fi | |||
| @@ -17,7 +17,7 @@ | |||
| #bash run_parameter_server_train_distribute.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET | |||
| # SERVER_NUM SCHED_HOST SCHED_PORT RANK_TABLE_FILE | |||
| # VOCAB_CACHE_SIZE | |||
| # VOCAB_CACHE_SIZE SPARSE | |||
| execute_path=$(pwd) | |||
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| @@ -33,11 +33,16 @@ export MS_SCHED_HOST=$6 | |||
| export MS_SCHED_PORT=$7 | |||
| export RANK_TABLE_FILE=$8 | |||
| export VOCAB_CACHE_SIZE=$9 | |||
| export SPARSE=${10} | |||
| if [[ ! -n "$9" ]]; then | |||
| export VOCAB_CACHE_SIZE=0 | |||
| fi | |||
| if [[ ! -n "${10}" ]]; then | |||
| export SPARSE=0 | |||
| fi | |||
| export MS_ROLE=MS_SCHED | |||
| rm -rf ${execute_path}/sched/ | |||
| mkdir ${execute_path}/sched/ | |||
| @@ -65,7 +70,7 @@ if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then | |||
| mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ | |||
| python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ | |||
| --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & | |||
| else | |||
| for((i=0;i<$MS_WORKER_NUM;i++)); | |||
| do | |||
| @@ -76,7 +81,7 @@ else | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \ | |||
| --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 & | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 & | |||
| done | |||
| fi | |||
| @@ -16,7 +16,7 @@ | |||
| #bash run_parameter_server_train_standalone.sh EPOCHS DEVICE_TARGET DATASET SERVER_NUM SCHED_HOST | |||
| # SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE | |||
| # SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE SPARSE | |||
| execute_path=$(pwd) | |||
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| @@ -31,11 +31,16 @@ export MS_SCHED_HOST=$5 | |||
| export MS_SCHED_PORT=$6 | |||
| DEVICE_ID=$7 | |||
| export VOCAB_CACHE_SIZE=$8 | |||
| export SPARSE=$9 | |||
| if [[ ! -n "$8" ]]; then | |||
| export VOCAB_CACHE_SIZE=0 | |||
| fi | |||
| if [[ ! -n "$9" ]]; then | |||
| export SPARSE=0 | |||
| fi | |||
| # Set device id | |||
| if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then | |||
| if [[ ! -n "$DEVICE_ID" ]]; then | |||
| @@ -76,4 +81,4 @@ mkdir ${execute_path}/worker/ | |||
| cd ${execute_path}/worker/ || exit | |||
| python -s ${self_path}/../train_and_eval_parameter_server_standalone.py --device_target=$DEVICE_TARGET \ | |||
| --epochs=$EPOCH_SIZE --data_path=$DATASET --parameter_server=1 \ | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 & | |||
| --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 & | |||
| @@ -115,8 +115,11 @@ class EvalCallBack(Callback): | |||
| if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, | |||
| ParallelMode.DATA_PARALLEL): | |||
| rank_id = get_rank() | |||
| enable_data_sink = not self.sparse | |||
| if bool(self.config.parameter_server): | |||
| enable_data_sink = True | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) | |||
| out = self.model.eval(self.eval_dataset, dataset_sink_mode=enable_data_sink) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| @@ -202,7 +202,7 @@ class WideDeepModel(nn.Cell): | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| self.wide_gatherv2 = P.GatherV2() | |||
| self.deep_gatherv2 = P.GatherV2() | |||
| if is_auto_parallel and sparse and not is_field_slice: | |||
| if is_auto_parallel and sparse and not is_field_slice and not parameter_server: | |||
| target = 'DEVICE' | |||
| if host_device_mix: | |||
| target = 'CPU' | |||
| @@ -376,12 +376,12 @@ class TrainStepWrap(nn.Cell): | |||
| self.weights_w = ParameterTuple(weights_w) | |||
| self.weights_d = ParameterTuple(weights_d) | |||
| if (sparse and is_auto_parallel) or (parameter_server and not cache_enable): | |||
| if (sparse and is_auto_parallel) or (sparse and parameter_server): | |||
| self.optimizer_d = LazyAdam( | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | |||
| l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) | |||
| if host_device_mix or parameter_server: | |||
| if host_device_mix or (parameter_server and not cache_enable): | |||
| self.optimizer_w.target = "CPU" | |||
| self.optimizer_d.target = "CPU" | |||
| else: | |||
| @@ -43,7 +43,7 @@ def get_wide_deep_net(config): | |||
| if cache_enable: | |||
| loss_net = VirtualDatasetCellTriple(loss_net) | |||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), | |||
| cache_enable=(config.vocab_cache_size > 0)) | |||
| sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0)) | |||
| eval_net = PredictWithSigmoid(wide_deep_net) | |||
| if cache_enable: | |||
| eval_net = VirtualDatasetCellTriple(eval_net) | |||
| @@ -138,7 +138,7 @@ def train_and_eval(config): | |||
| callback_list.append(ckpoint_cb) | |||
| model.train(epochs, ds_train, | |||
| callbacks=callback_list, | |||
| dataset_sink_mode=bool(parameter_server and cache_enable)) | |||
| dataset_sink_mode=(parameter_server and cache_enable)) | |||
| if __name__ == "__main__": | |||
| @@ -148,7 +148,6 @@ if __name__ == "__main__": | |||
| cache_enable = wide_deep_config.vocab_cache_size > 0 | |||
| if cache_enable and wide_deep_config.device_target != "GPU": | |||
| context.set_context(variable_memory_max_size="24GB") | |||
| context.set_context(enable_sparse=True) | |||
| context.set_ps_context(enable_ps=True) | |||
| init() | |||
| context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) | |||
| @@ -159,5 +158,8 @@ if __name__ == "__main__": | |||
| else: | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=get_group_size()) | |||
| wide_deep_config.sparse = True | |||
| if wide_deep_config.sparse: | |||
| context.set_context(enable_sparse=True) | |||
| train_and_eval(wide_deep_config) | |||
| @@ -29,7 +29,6 @@ from src.metrics import AUCMetric | |||
| from src.config import WideDeepConfig | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| context.set_context(enable_sparse=True) | |||
| def get_wide_deep_net(config): | |||
| @@ -39,7 +38,7 @@ def get_wide_deep_net(config): | |||
| wide_deep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(wide_deep_net, config) | |||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), | |||
| cache_enable=(config.vocab_cache_size > 0)) | |||
| sparse=config.sparse, cache_enable=(config.vocab_cache_size > 0)) | |||
| eval_net = PredictWithSigmoid(wide_deep_net) | |||
| return train_net, eval_net | |||
| @@ -81,7 +80,6 @@ def train_and_eval(config): | |||
| else: | |||
| dataset_type = DataType.H5 | |||
| parameter_server = bool(config.parameter_server) | |||
| cache_enable = config.vocab_cache_size > 0 | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size, data_type=dataset_type) | |||
| @@ -121,6 +119,11 @@ if __name__ == "__main__": | |||
| wide_deep_config.argparse_init() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) | |||
| cache_enable = wide_deep_config.vocab_cache_size > 0 | |||
| if not cache_enable: | |||
| wide_deep_config.sparse = True | |||
| if wide_deep_config.sparse: | |||
| context.set_context(enable_sparse=True) | |||
| context.set_ps_context(enable_ps=True) | |||
| train_and_eval(wide_deep_config) | |||
| @@ -0,0 +1,128 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| callbacks | |||
| """ | |||
| import time | |||
| from mindspore.train.callback import Callback | |||
| from mindspore import context | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import get_rank | |||
| def add_write(file_path, out_str): | |||
| """ | |||
| add lines to the file | |||
| """ | |||
| with open(file_path, 'a+', encoding="utf-8") as file_out: | |||
| file_out.write(out_str + "\n") | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF, terminate the training. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| If this process is MS_PSERVER role, do not run callbacks. | |||
| Args: | |||
| per_print_times (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, config=None, per_print_times=1): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("per_print_times must be in and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| self.config = config | |||
| def step_end(self, run_context): | |||
| """Monitor the loss in training.""" | |||
| cb_params = run_context.original_args() | |||
| if cb_params.net_outputs is None: | |||
| return | |||
| wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| cur_num = cb_params.cur_step_num | |||
| rank_id = 0 | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, | |||
| ParallelMode.DATA_PARALLEL): | |||
| rank_id = get_rank() | |||
| print("===loss===", rank_id, cb_params.cur_epoch_num, cur_step_in_epoch, | |||
| wide_loss, deep_loss, flush=True) | |||
| # raise ValueError | |||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None: | |||
| loss_file = open(self.config.loss_file_name, "a+") | |||
| loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| loss_file.write("\n") | |||
| loss_file.close() | |||
| print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| class EvalCallBack(Callback): | |||
| """ | |||
| Monitor the loss in evaluating. | |||
| If the loss is NAN or INF, terminate evaluating. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| Args: | |||
| print_per_step (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): | |||
| super(EvalCallBack, self).__init__() | |||
| if not isinstance(print_per_step, int) or print_per_step < 0: | |||
| raise ValueError("print_per_step must be int and >= 0.") | |||
| self.print_per_step = print_per_step | |||
| self.model = model | |||
| self.eval_dataset = eval_dataset | |||
| self.aucMetric = auc_metric | |||
| self.aucMetric.clear() | |||
| self.eval_file_name = config.eval_file_name | |||
| self.eval_values = [] | |||
| self.sparse = config.sparse | |||
| self.config = config | |||
| def epoch_end(self, run_context): | |||
| """ | |||
| epoch end | |||
| """ | |||
| self.aucMetric.clear() | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||
| context.set_auto_parallel_context(strategy_ckpt_save_file="", | |||
| strategy_ckpt_load_file=self.config.stra_ckpt) | |||
| rank_id = 0 | |||
| if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL, | |||
| ParallelMode.DATA_PARALLEL): | |||
| rank_id = get_rank() | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) | |||
| out_str = "{} == Rank: {} == EvalCallBack model.eval(): {}; eval_time: {}s".\ | |||
| format(time_str, rank_id, out.values(), eval_time) | |||
| print(out_str) | |||
| self.eval_values = out.values() | |||
| add_write(self.eval_file_name, out_str) | |||
| @@ -34,6 +34,7 @@ cp -r ${CODE_DIR} ${BASE_PATH}/wide_and_deep | |||
| cp -f ${BASE_PATH}/python_file_for_ci/train_and_test_multinpu_ci.py ${BASE_PATH}/wide_and_deep/train_and_test_multinpu_ci.py | |||
| cp -f ${BASE_PATH}/python_file_for_ci/__init__.py ${BASE_PATH}/wide_and_deep/__init__.py | |||
| cp -f ${BASE_PATH}/python_file_for_ci/config.py ${BASE_PATH}/wide_and_deep/src/config.py | |||
| cp -f ${BASE_PATH}/python_file_for_ci/callbacks.py ${BASE_PATH}/wide_and_deep/src/callbacks.py | |||
| cp -f ${BASE_PATH}/python_file_for_ci/datasets.py ${BASE_PATH}/wide_and_deep/src/datasets.py | |||
| cp -f ${BASE_PATH}/python_file_for_ci/wide_and_deep.py ${BASE_PATH}/wide_and_deep/src/wide_and_deep.py | |||
| source ${BASE_PATH}/env.sh | |||
| @@ -55,7 +56,7 @@ for((i=0; i<${DEVICE_NUM}; i++)); do | |||
| wait ${process_pid[i]} | |||
| status=`echo $?` | |||
| if [ "${status}" != "0" ]; then | |||
| echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}" | |||
| echo "[ERROR] test wide_and_deep semi auto parallel failed. status: ${status}" | |||
| exit 1 | |||
| else | |||
| echo "[INFO] test wide_and_deep semi auto parallel success." | |||