From 547a49356bcd12bd8e31e81383c19fdec31f4b9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cxujincai=E2=80=9D?= <“xujincai@huawei.com”> Date: Thu, 25 Feb 2021 10:22:10 +0800 Subject: [PATCH] ut for agent config acquire --- .../notify_distributed/notify_worker.cc | 58 +++++---- .../notify_distributed/notify_worker.h | 2 + .../ut/cpp/tests/test_agent_config_acquire.cc | 117 ++++++++++++++++++ .../cpp/tests/test_init_config_on_start_up.cc | 2 + 4 files changed, 152 insertions(+), 27 deletions(-) create mode 100644 tests/ut/cpp/tests/test_agent_config_acquire.cc diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index 82d14cd..afe3b38 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -115,33 +115,7 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string context.set_deadline(deadline); grpc::Status status = stub->AgentConfigAcquire(&context, request, &reply); if (status.ok()) { - MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser"; - // parser reply message:AgentConfigAcquireReply, parameter:rank_table_content - config->rank_table_content = reply.rank_table_content(); - // parser reply message:AgentConfigAcquireReply, parameter:rank_list - for (auto &temp_rank : reply.rank_list()) { - OneRankConfig ome_rank_config; - ome_rank_config.ip = temp_rank.ip(); - ome_rank_config.device_id = temp_rank.device_id(); - config->rank_list.push_back(ome_rank_config); - } - // parser reply message:AgentConfigAcquireReply, parameter:common_meta - auto &temp_common_meta = reply.common_meta(); - config->common_meta.servable_name = temp_common_meta.servable_name(); - config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim(); - for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) { - config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs); - } - config->common_meta.inputs_count = temp_common_meta.inputs_count(); - config->common_meta.outputs_count = temp_common_meta.outputs_count(); - - // parser reply message:AgentConfigAcquireReply, parameter:distributed_meta - auto &temp_distributed_meta = reply.distributed_meta(); - config->distributed_meta.rank_size = temp_distributed_meta.rank_size(); - config->distributed_meta.stage_size = temp_distributed_meta.stage_size(); - MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig"; - - return SUCCESS; + return ParseAgentConfigAcquireReply(reply, config); } MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); @@ -151,6 +125,36 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to get Agents configs from Worker"; } +Status GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, + DistributedServableConfig *config) { + MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser"; + // parser reply message:AgentConfigAcquireReply, parameter:rank_table_content + config->rank_table_content = reply.rank_table_content(); + // parser reply message:AgentConfigAcquireReply, parameter:rank_list + for (auto &temp_rank : reply.rank_list()) { + OneRankConfig ome_rank_config; + ome_rank_config.ip = temp_rank.ip(); + ome_rank_config.device_id = temp_rank.device_id(); + config->rank_list.push_back(ome_rank_config); + } + // parser reply message:AgentConfigAcquireReply, parameter:common_meta + auto &temp_common_meta = reply.common_meta(); + config->common_meta.servable_name = temp_common_meta.servable_name(); + config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim(); + for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) { + config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs); + } + config->common_meta.inputs_count = temp_common_meta.inputs_count(); + config->common_meta.outputs_count = temp_common_meta.outputs_count(); + + // parser reply message:AgentConfigAcquireReply, parameter:distributed_meta + auto &temp_distributed_meta = reply.distributed_meta(); + config->distributed_meta.rank_size = temp_distributed_meta.rank_size(); + config->distributed_meta.stage_size = temp_distributed_meta.stage_size(); + MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig"; + + return SUCCESS; +} } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index cf11d55..830fba9 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -41,6 +41,8 @@ class MS_API GrpcNotifyDistributeWorker { DistributedServableConfig *config); private: + static Status ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, + DistributedServableConfig *config); std::string distributed_worker_ip_; uint32_t distributed_worker_port_; std::string host_ip_; diff --git a/tests/ut/cpp/tests/test_agent_config_acquire.cc b/tests/ut/cpp/tests/test_agent_config_acquire.cc new file mode 100644 index 0000000..02b7645 --- /dev/null +++ b/tests/ut/cpp/tests/test_agent_config_acquire.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" +#include "common/tensor_base.h" +#define private public +#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" +#undef private + +using std::string; +using std::vector; +namespace mindspore { +namespace serving { +class TestAgentConfigAcquire : public UT::Common { + public: + TestAgentConfigAcquire() = default; + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_success) { + std::shared_ptr servable = std::make_shared(); + std::string rank_table_content = "rank table content"; + CommonServableMeta commonServableMeta; + commonServableMeta.servable_name = "servable_name"; + commonServableMeta.outputs_count = 1; + commonServableMeta.inputs_count = 1; + commonServableMeta.with_batch_dim = false; + commonServableMeta.without_batch_dim_inputs.push_back(8); + DistributedServableMeta distributedServableMeta; + distributedServableMeta.stage_size = 8; + distributedServableMeta.rank_size = 8; + OneRankConfig oneRankConfig; + oneRankConfig.ip = "1.1.1.1"; + oneRankConfig.device_id = 0; + servable->config_.rank_table_content = rank_table_content; + servable->config_.common_meta = commonServableMeta; + servable->config_.distributed_meta = distributedServableMeta; + servable->config_.rank_list.push_back(oneRankConfig); + servable->config_loaded_ = true; + const std::string server_address = "any_addr"; + MSDistributedImpl mSDistributedImpl(servable, server_address); + grpc::ServerContext context; + const proto::AgentConfigAcquireRequest request; + proto::AgentConfigAcquireReply reply; + grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); + ASSERT_EQ(status.error_code(), 0); + + DistributedServableConfig config; + GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(reply, &config); + ASSERT_EQ(config.rank_table_content, rank_table_content); + ASSERT_EQ(config.common_meta.servable_name, "servable_name"); + ASSERT_EQ(config.common_meta.inputs_count, 1); + ASSERT_EQ(config.common_meta.outputs_count, 1); + ASSERT_EQ(config.common_meta.with_batch_dim, false); + ASSERT_EQ(config.common_meta.without_batch_dim_inputs.size(), 1); + ASSERT_EQ(config.common_meta.without_batch_dim_inputs.at(0), 8); + ASSERT_EQ(config.distributed_meta.rank_size, 8); + ASSERT_EQ(config.distributed_meta.stage_size, 8); + ASSERT_EQ(config.rank_list.size(), 1); + OneRankConfig tempRankConfig = config.rank_list.at(0); + ASSERT_EQ(tempRankConfig.device_id, 0); + ASSERT_EQ(tempRankConfig.ip, "1.1.1.1"); +} + +TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_load_config_failed) { + std::shared_ptr servable = std::make_shared(); + servable->config_loaded_ = true; + const std::string server_address = "any_addr"; + MSDistributedImpl mSDistributedImpl(servable, server_address); + grpc::ServerContext context; + const proto::AgentConfigAcquireRequest request; + proto::AgentConfigAcquireReply reply; + const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); + ASSERT_EQ(status.error_code(), 1); +} + +TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_init_config_failed) { + std::shared_ptr servable = std::make_shared(); + std::string rank_table_content = "rank table content"; + CommonServableMeta commonServableMeta; + commonServableMeta.servable_name = "servable_name"; + commonServableMeta.outputs_count = 1; + commonServableMeta.inputs_count = 1; + commonServableMeta.with_batch_dim = false; + commonServableMeta.without_batch_dim_inputs.push_back(8); + DistributedServableMeta distributedServableMeta; + distributedServableMeta.stage_size = 8; + distributedServableMeta.rank_size = 8; + servable->config_.rank_table_content = rank_table_content; + servable->config_.common_meta = commonServableMeta; + servable->config_.distributed_meta = distributedServableMeta; + servable->config_loaded_ = true; + const std::string server_address = "any_addr"; + MSDistributedImpl mSDistributedImpl(servable, server_address); + grpc::ServerContext context; + const proto::AgentConfigAcquireRequest request; + proto::AgentConfigAcquireReply reply; + const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); + ASSERT_EQ(status.error_code(), 1); +} + +} // namespace serving +} // namespace mindspore diff --git a/tests/ut/cpp/tests/test_init_config_on_start_up.cc b/tests/ut/cpp/tests/test_init_config_on_start_up.cc index f235bb9..fcbc7e9 100644 --- a/tests/ut/cpp/tests/test_init_config_on_start_up.cc +++ b/tests/ut/cpp/tests/test_init_config_on_start_up.cc @@ -40,6 +40,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_empty_file_failed) { std::ofstream fp(empty_rank_table_file); fp << "empty rank table file"; fp.close(); + config_file_list_.emplace(empty_rank_table_file); auto servable = std::make_shared(); auto status = servable->InitConfigOnStartup(empty_rank_table_file); ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); @@ -64,6 +65,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_success) { std::ofstream fp(rank_table_file); fp << rank_table_server_list; fp.close(); + config_file_list_.emplace(rank_table_file); auto servable = std::make_shared(); auto status = servable->InitConfigOnStartup(rank_table_file); ASSERT_EQ(status.StatusCode(), SUCCESS);