Browse Source

ut for agent config acquire

tags/v1.2.0
“xujincai” 5 years ago
parent
commit
547a49356b
4 changed files with 152 additions and 27 deletions
  1. +31
    -27
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
  2. +2
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
  3. +117
    -0
      tests/ut/cpp/tests/test_agent_config_acquire.cc
  4. +2
    -0
      tests/ut/cpp/tests/test_init_config_on_start_up.cc

+ 31
- 27
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc View File

@@ -115,33 +115,7 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string
context.set_deadline(deadline);
grpc::Status status = stub->AgentConfigAcquire(&context, request, &reply);
if (status.ok()) {
MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser";
// parser reply message:AgentConfigAcquireReply, parameter:rank_table_content
config->rank_table_content = reply.rank_table_content();
// parser reply message:AgentConfigAcquireReply, parameter:rank_list
for (auto &temp_rank : reply.rank_list()) {
OneRankConfig ome_rank_config;
ome_rank_config.ip = temp_rank.ip();
ome_rank_config.device_id = temp_rank.device_id();
config->rank_list.push_back(ome_rank_config);
}
// parser reply message:AgentConfigAcquireReply, parameter:common_meta
auto &temp_common_meta = reply.common_meta();
config->common_meta.servable_name = temp_common_meta.servable_name();
config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim();
for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) {
config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs);
}
config->common_meta.inputs_count = temp_common_meta.inputs_count();
config->common_meta.outputs_count = temp_common_meta.outputs_count();

// parser reply message:AgentConfigAcquireReply, parameter:distributed_meta
auto &temp_distributed_meta = reply.distributed_meta();
config->distributed_meta.rank_size = temp_distributed_meta.rank_size();
config->distributed_meta.stage_size = temp_distributed_meta.stage_size();
MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig";

return SUCCESS;
return ParseAgentConfigAcquireReply(reply, config);
}
MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message();
std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000));
@@ -151,6 +125,36 @@ Status GrpcNotifyDistributeWorker::GetAgentsConfigsFromWorker(const std::string
}
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to get Agents configs from Worker";
}
Status GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply,
DistributedServableConfig *config) {
MSI_LOG(INFO) << "Success to get Agents configs from Worker, and begin to parser";
// parser reply message:AgentConfigAcquireReply, parameter:rank_table_content
config->rank_table_content = reply.rank_table_content();
// parser reply message:AgentConfigAcquireReply, parameter:rank_list
for (auto &temp_rank : reply.rank_list()) {
OneRankConfig ome_rank_config;
ome_rank_config.ip = temp_rank.ip();
ome_rank_config.device_id = temp_rank.device_id();
config->rank_list.push_back(ome_rank_config);
}
// parser reply message:AgentConfigAcquireReply, parameter:common_meta
auto &temp_common_meta = reply.common_meta();
config->common_meta.servable_name = temp_common_meta.servable_name();
config->common_meta.with_batch_dim = temp_common_meta.with_batch_dim();
for (auto &temp_without_batch_dim_inputs : temp_common_meta.without_batch_dim_inputs()) {
config->common_meta.without_batch_dim_inputs.push_back(temp_without_batch_dim_inputs);
}
config->common_meta.inputs_count = temp_common_meta.inputs_count();
config->common_meta.outputs_count = temp_common_meta.outputs_count();

// parser reply message:AgentConfigAcquireReply, parameter:distributed_meta
auto &temp_distributed_meta = reply.distributed_meta();
config->distributed_meta.rank_size = temp_distributed_meta.rank_size();
config->distributed_meta.stage_size = temp_distributed_meta.stage_size();
MSI_LOG(INFO) << "Success to parser reply message and save to DistributedServableConfig";

return SUCCESS;
}

} // namespace serving
} // namespace mindspore

+ 2
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h View File

@@ -41,6 +41,8 @@ class MS_API GrpcNotifyDistributeWorker {
DistributedServableConfig *config);

private:
static Status ParseAgentConfigAcquireReply(const proto::AgentConfigAcquireReply &reply,
DistributedServableConfig *config);
std::string distributed_worker_ip_;
uint32_t distributed_worker_port_;
std::string host_ip_;


+ 117
- 0
tests/ut/cpp/tests/test_agent_config_acquire.cc View File

@@ -0,0 +1,117 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "common/tensor_base.h"
#define private public
#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#undef private

using std::string;
using std::vector;
namespace mindspore {
namespace serving {
class TestAgentConfigAcquire : public UT::Common {
public:
TestAgentConfigAcquire() = default;
virtual void SetUp() {}
virtual void TearDown() {}
};

TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_success) {
std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>();
std::string rank_table_content = "rank table content";
CommonServableMeta commonServableMeta;
commonServableMeta.servable_name = "servable_name";
commonServableMeta.outputs_count = 1;
commonServableMeta.inputs_count = 1;
commonServableMeta.with_batch_dim = false;
commonServableMeta.without_batch_dim_inputs.push_back(8);
DistributedServableMeta distributedServableMeta;
distributedServableMeta.stage_size = 8;
distributedServableMeta.rank_size = 8;
OneRankConfig oneRankConfig;
oneRankConfig.ip = "1.1.1.1";
oneRankConfig.device_id = 0;
servable->config_.rank_table_content = rank_table_content;
servable->config_.common_meta = commonServableMeta;
servable->config_.distributed_meta = distributedServableMeta;
servable->config_.rank_list.push_back(oneRankConfig);
servable->config_loaded_ = true;
const std::string server_address = "any_addr";
MSDistributedImpl mSDistributedImpl(servable, server_address);
grpc::ServerContext context;
const proto::AgentConfigAcquireRequest request;
proto::AgentConfigAcquireReply reply;
grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply);
ASSERT_EQ(status.error_code(), 0);

DistributedServableConfig config;
GrpcNotifyDistributeWorker::ParseAgentConfigAcquireReply(reply, &config);
ASSERT_EQ(config.rank_table_content, rank_table_content);
ASSERT_EQ(config.common_meta.servable_name, "servable_name");
ASSERT_EQ(config.common_meta.inputs_count, 1);
ASSERT_EQ(config.common_meta.outputs_count, 1);
ASSERT_EQ(config.common_meta.with_batch_dim, false);
ASSERT_EQ(config.common_meta.without_batch_dim_inputs.size(), 1);
ASSERT_EQ(config.common_meta.without_batch_dim_inputs.at(0), 8);
ASSERT_EQ(config.distributed_meta.rank_size, 8);
ASSERT_EQ(config.distributed_meta.stage_size, 8);
ASSERT_EQ(config.rank_list.size(), 1);
OneRankConfig tempRankConfig = config.rank_list.at(0);
ASSERT_EQ(tempRankConfig.device_id, 0);
ASSERT_EQ(tempRankConfig.ip, "1.1.1.1");
}

TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_load_config_failed) {
std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>();
servable->config_loaded_ = true;
const std::string server_address = "any_addr";
MSDistributedImpl mSDistributedImpl(servable, server_address);
grpc::ServerContext context;
const proto::AgentConfigAcquireRequest request;
proto::AgentConfigAcquireReply reply;
const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply);
ASSERT_EQ(status.error_code(), 1);
}

TEST_F(TestAgentConfigAcquire, test_agent_config_acquire_not_init_config_failed) {
std::shared_ptr<DistributedServable> servable = std::make_shared<DistributedServable>();
std::string rank_table_content = "rank table content";
CommonServableMeta commonServableMeta;
commonServableMeta.servable_name = "servable_name";
commonServableMeta.outputs_count = 1;
commonServableMeta.inputs_count = 1;
commonServableMeta.with_batch_dim = false;
commonServableMeta.without_batch_dim_inputs.push_back(8);
DistributedServableMeta distributedServableMeta;
distributedServableMeta.stage_size = 8;
distributedServableMeta.rank_size = 8;
servable->config_.rank_table_content = rank_table_content;
servable->config_.common_meta = commonServableMeta;
servable->config_.distributed_meta = distributedServableMeta;
servable->config_loaded_ = true;
const std::string server_address = "any_addr";
MSDistributedImpl mSDistributedImpl(servable, server_address);
grpc::ServerContext context;
const proto::AgentConfigAcquireRequest request;
proto::AgentConfigAcquireReply reply;
const grpc::Status status = mSDistributedImpl.AgentConfigAcquire(&context, &request, &reply);
ASSERT_EQ(status.error_code(), 1);
}

} // namespace serving
} // namespace mindspore

+ 2
- 0
tests/ut/cpp/tests/test_init_config_on_start_up.cc View File

@@ -40,6 +40,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_empty_file_failed) {
std::ofstream fp(empty_rank_table_file);
fp << "empty rank table file";
fp.close();
config_file_list_.emplace(empty_rank_table_file);
auto servable = std::make_shared<DistributedServable>();
auto status = servable->InitConfigOnStartup(empty_rank_table_file);
ASSERT_EQ(status.StatusCode(), INVALID_INPUTS);
@@ -64,6 +65,7 @@ TEST_F(TestParseRankTableFile, test_init_config_on_startup_success) {
std::ofstream fp(rank_table_file);
fp << rank_table_server_list;
fp.close();
config_file_list_.emplace(rank_table_file);
auto servable = std::make_shared<DistributedServable>();
auto status = servable->InitConfigOnStartup(rank_table_file);
ASSERT_EQ(status.StatusCode(), SUCCESS);


Loading…
Cancel
Save