| @@ -115,33 +115,7 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string | |||||
| context.set_deadline(deadline); | context.set_deadline(deadline); | ||||
| grpc::Status status = stub->AgentConfigAcquire(&context, request, &reply); | grpc::Status status = stub->AgentConfigAcquire(&context, request, &reply); | ||||
| if (status.ok()) { | if (status.ok()) { | ||||
| MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser"; | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:rank_table_content | |||||
| config->rank_table_content = reply.rank_table_content(); | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:rank_list | |||||
| for (auto &temp_rank : reply.rank_list()) { | |||||
| OneRankConfig ome_rank_config; | |||||
| ome_rank_config.ip = temp_rank.ip(); | |||||
| ome_rank_config.device_id = temp_rank.device_id(); | |||||
| config->rank_list.push_back(ome_rank_config); | |||||
| } | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:common_meta | |||||
| auto &temp_common_meta = reply.common_meta(); | |||||
| config->common_meta.servable_name = temp_common_meta.servable_name(); | |||||
| config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim(); | |||||
| for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) { | |||||
| config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs); | |||||
| } | |||||
| config->common_meta.inputs_count = temp_common_meta.inputs_count(); | |||||
| config->common_meta.outputs_count = temp_common_meta.outputs_count(); | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:distributed_meta | |||||
| auto &temp_distributed_meta = reply.distributed_meta(); | |||||
| config->distributed_meta.rank_size = temp_distributed_meta.rank_size(); | |||||
| config->distributed_meta.stage_size = temp_distributed_meta.stage_size(); | |||||
| MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig"; | |||||
| return SUCCESS; | |||||
| return ParseAgentConfigAcquireReply(reply, config); | |||||
| } | } | ||||
| MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); | MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); | ||||
| std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); | std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); | ||||
| @@ -151,6 +125,36 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string | |||||
| } | } | ||||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to get Agents configs from Worker"; | return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to get Agents configs from Worker"; | ||||
| } | } | ||||
| Status GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, | |||||
| DistributedServableConfig *config) { | |||||
| MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser"; | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:rank_table_content | |||||
| config->rank_table_content = reply.rank_table_content(); | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:rank_list | |||||
| for (auto &temp_rank : reply.rank_list()) { | |||||
| OneRankConfig ome_rank_config; | |||||
| ome_rank_config.ip = temp_rank.ip(); | |||||
| ome_rank_config.device_id = temp_rank.device_id(); | |||||
| config->rank_list.push_back(ome_rank_config); | |||||
| } | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:common_meta | |||||
| auto &temp_common_meta = reply.common_meta(); | |||||
| config->common_meta.servable_name = temp_common_meta.servable_name(); | |||||
| config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim(); | |||||
| for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) { | |||||
| config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs); | |||||
| } | |||||
| config->common_meta.inputs_count = temp_common_meta.inputs_count(); | |||||
| config->common_meta.outputs_count = temp_common_meta.outputs_count(); | |||||
| // parser reply message:AgentConfigAcquireReply, parameter:distributed_meta | |||||
| auto &temp_distributed_meta = reply.distributed_meta(); | |||||
| config->distributed_meta.rank_size = temp_distributed_meta.rank_size(); | |||||
| config->distributed_meta.stage_size = temp_distributed_meta.stage_size(); | |||||
| MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig"; | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace serving | } // namespace serving | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,6 +41,8 @@ class MS_API GrpcNotifyDistributeWorker { | |||||
| DistributedServableConfig *config); | DistributedServableConfig *config); | ||||
| private: | private: | ||||
| static Status ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply, | |||||
| DistributedServableConfig *config); | |||||
| std::string distributed_worker_ip_; | std::string distributed_worker_ip_; | ||||
| uint32_t distributed_worker_port_; | uint32_t distributed_worker_port_; | ||||
| std::string host_ip_; | std::string host_ip_; | ||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #include "common/tensor_base.h" | |||||
| #define private public | |||||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||||
| #undef private | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
| class TestAgentConfigAcquire : public UT::Common { | |||||
| public: | |||||
| TestAgentConfigAcquire() = default; | |||||
| virtual void SetUp() {} | |||||
| virtual void TearDown() {} | |||||
| }; | |||||
| TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_success) { | |||||
| std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>(); | |||||
| std::string rank_table_content = "rank table content"; | |||||
| CommonServableMeta commonServableMeta; | |||||
| commonServableMeta.servable_name = "servable_name"; | |||||
| commonServableMeta.outputs_count = 1; | |||||
| commonServableMeta.inputs_count = 1; | |||||
| commonServableMeta.with_batch_dim = false; | |||||
| commonServableMeta.without_batch_dim_inputs.push_back(8); | |||||
| DistributedServableMeta distributedServableMeta; | |||||
| distributedServableMeta.stage_size = 8; | |||||
| distributedServableMeta.rank_size = 8; | |||||
| OneRankConfig oneRankConfig; | |||||
| oneRankConfig.ip = "1.1.1.1"; | |||||
| oneRankConfig.device_id = 0; | |||||
| servable->config_.rank_table_content = rank_table_content; | |||||
| servable->config_.common_meta = commonServableMeta; | |||||
| servable->config_.distributed_meta = distributedServableMeta; | |||||
| servable->config_.rank_list.push_back(oneRankConfig); | |||||
| servable->config_loaded_ = true; | |||||
| const std::string server_address = "any_addr"; | |||||
| MSDistributedImpl mSDistributedImpl(servable, server_address); | |||||
| grpc::ServerContext context; | |||||
| const proto::AgentConfigAcquireRequest request; | |||||
| proto::AgentConfigAcquireReply reply; | |||||
| grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); | |||||
| ASSERT_EQ(status.error_code(), 0); | |||||
| DistributedServableConfig config; | |||||
| GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(reply, &config); | |||||
| ASSERT_EQ(config.rank_table_content, rank_table_content); | |||||
| ASSERT_EQ(config.common_meta.servable_name, "servable_name"); | |||||
| ASSERT_EQ(config.common_meta.inputs_count, 1); | |||||
| ASSERT_EQ(config.common_meta.outputs_count, 1); | |||||
| ASSERT_EQ(config.common_meta.with_batch_dim, false); | |||||
| ASSERT_EQ(config.common_meta.without_batch_dim_inputs.size(), 1); | |||||
| ASSERT_EQ(config.common_meta.without_batch_dim_inputs.at(0), 8); | |||||
| ASSERT_EQ(config.distributed_meta.rank_size, 8); | |||||
| ASSERT_EQ(config.distributed_meta.stage_size, 8); | |||||
| ASSERT_EQ(config.rank_list.size(), 1); | |||||
| OneRankConfig tempRankConfig = config.rank_list.at(0); | |||||
| ASSERT_EQ(tempRankConfig.device_id, 0); | |||||
| ASSERT_EQ(tempRankConfig.ip, "1.1.1.1"); | |||||
| } | |||||
| TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_load_config_failed) { | |||||
| std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>(); | |||||
| servable->config_loaded_ = true; | |||||
| const std::string server_address = "any_addr"; | |||||
| MSDistributedImpl mSDistributedImpl(servable, server_address); | |||||
| grpc::ServerContext context; | |||||
| const proto::AgentConfigAcquireRequest request; | |||||
| proto::AgentConfigAcquireReply reply; | |||||
| const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); | |||||
| ASSERT_EQ(status.error_code(), 1); | |||||
| } | |||||
| TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_init_config_failed) { | |||||
| std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>(); | |||||
| std::string rank_table_content = "rank table content"; | |||||
| CommonServableMeta commonServableMeta; | |||||
| commonServableMeta.servable_name = "servable_name"; | |||||
| commonServableMeta.outputs_count = 1; | |||||
| commonServableMeta.inputs_count = 1; | |||||
| commonServableMeta.with_batch_dim = false; | |||||
| commonServableMeta.without_batch_dim_inputs.push_back(8); | |||||
| DistributedServableMeta distributedServableMeta; | |||||
| distributedServableMeta.stage_size = 8; | |||||
| distributedServableMeta.rank_size = 8; | |||||
| servable->config_.rank_table_content = rank_table_content; | |||||
| servable->config_.common_meta = commonServableMeta; | |||||
| servable->config_.distributed_meta = distributedServableMeta; | |||||
| servable->config_loaded_ = true; | |||||
| const std::string server_address = "any_addr"; | |||||
| MSDistributedImpl mSDistributedImpl(servable, server_address); | |||||
| grpc::ServerContext context; | |||||
| const proto::AgentConfigAcquireRequest request; | |||||
| proto::AgentConfigAcquireReply reply; | |||||
| const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply); | |||||
| ASSERT_EQ(status.error_code(), 1); | |||||
| } | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -40,6 +40,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_empty_file_failed) { | |||||
| std::ofstream fp(empty_rank_table_file); | std::ofstream fp(empty_rank_table_file); | ||||
| fp << "empty rank table file"; | fp << "empty rank table file"; | ||||
| fp.close(); | fp.close(); | ||||
| config_file_list_.emplace(empty_rank_table_file); | |||||
| auto servable = std::make_shared<DistributedServable>(); | auto servable = std::make_shared<DistributedServable>(); | ||||
| auto status = servable->InitConfigOnStartup(empty_rank_table_file); | auto status = servable->InitConfigOnStartup(empty_rank_table_file); | ||||
| ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); | ASSERT_EQ(status.StatusCode(), INVALID_INPUTS); | ||||
| @@ -64,6 +65,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_success) { | |||||
| std::ofstream fp(rank_table_file); | std::ofstream fp(rank_table_file); | ||||
| fp << rank_table_server_list; | fp << rank_table_server_list; | ||||
| fp.close(); | fp.close(); | ||||
| config_file_list_.emplace(rank_table_file); | |||||
| auto servable = std::make_shared<DistributedServable>(); | auto servable = std::make_shared<DistributedServable>(); | ||||
| auto status = servable->InitConfigOnStartup(rank_table_file); | auto status = servable->InitConfigOnStartup(rank_table_file); | ||||
| ASSERT_EQ(status.StatusCode(), SUCCESS); | ASSERT_EQ(status.StatusCode(), SUCCESS); | ||||