Browse Source

Add unit tests for distributed prediction

tags/v1.2.0
qinzheng 5 years ago
parent
commit
56a2c7e08e
1 changed files with 256 additions and 0 deletions
  1. +256
    -0
      tests/ut/cpp/tests/test_distributed_inference.cc

+ 256
- 0
tests/ut/cpp/tests/test_distributed_inference.cc View File

@@ -0,0 +1,256 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thread>
#include <chrono>
#include <vector>
#include "gtest/gtest.h"
#include "common/status.h"
#include "proto/ms_agent.pb.h"
#include "tests/ut/cpp/common/common_test.h"
#include "common/grpc_client.h"
#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
#define private public
#include "common/exit_handle.h"
#include "worker/distributed_worker/distributed_servable.h"
#undef private

namespace mindspore {
namespace serving {

// Per-rank scripted outcome for a FakeNotifyAgent: how long the fake
// inference takes and what status / error payload it reports back.
struct AgentInferResult {
int64_t prediction_time = 0; // simulated inference latency, milliseconds
Status status = SUCCESS;     // status passed to the async predict callback
int64_t error_code = 0;      // error code written into the reply's error_msg
std::string error_msg = "";  // error text written into the reply (only if non-empty)
};

// Test double for BaseNotifyAgent: completes DispatchAsync on a detached
// background thread after a configurable delay, reporting a scripted
// status / error code / error message.
class FakeNotifyAgent : public BaseNotifyAgent {
public:
explicit FakeNotifyAgent(int64_t prediction_time = 0, Status status = SUCCESS, int64_t error_code = 0,
std::string error_msg = "")
// std::move the sink parameter instead of copying it into the member.
: prediction_time_(prediction_time), status_(status), error_code_(error_code), error_msg_(std::move(error_msg)) {}
~FakeNotifyAgent() = default;
Status Exit() override { return SUCCESS; }
// Writes the scripted error payload into *reply immediately, then invokes
// `callback(status_)` from a detached thread after prediction_time_ ms.
Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
AsyncPredictCallback callback) override {
auto error_msg = reply->mutable_error_msg();
error_msg->set_error_code(error_code_);
if (!error_msg_.empty()) {
error_msg->set_error_msg(error_msg_);
}

// Capture everything BY VALUE: the detached thread may outlive this agent
// (e.g. when Predict times out after 10s but the fake sleeps 11s and the
// test tears the servable down), so it must not dereference `this`, as the
// previous `[=]` capture implicitly did.
auto predict = [callback = std::move(callback), duration = prediction_time_, status = status_]() {
std::this_thread::sleep_for(std::chrono::milliseconds(duration));
callback(status);
};
std::thread t1(predict);
t1.detach();
return SUCCESS;
}

private:
int64_t prediction_time_; // milliseconds
Status status_;
int64_t error_code_;
std::string error_msg_;
};

class TestDistributedInference : public UT::Common {
public:
TestDistributedInference() = default;
~TestDistributedInference() = default;

void InitDistributedServable(std::shared_ptr<DistributedServable> servable, size_t rank_size, size_t stage_size,
bool is_running, bool is_loaded) {
ExitSignalHandle::Instance().is_running_ = is_running;
servable->model_loaded_ = is_loaded;
servable->config_.distributed_meta.rank_size = rank_size;
servable->config_.distributed_meta.stage_size = stage_size;
}

void InitAgentSpecMap(std::shared_ptr<DistributedServable> servable,
const std::vector<AgentInferResult> &result_list) {
for (size_t rank_id = 0; rank_id < result_list.size(); ++rank_id) {
const auto &result = result_list[rank_id];
DistributedAgentContext agent_context;
agent_context.notify_agent_ =
std::make_shared<FakeNotifyAgent>(result.prediction_time, result.status, result.error_code, result.error_msg);
servable->agent_spec_map_.insert({rank_id, agent_context});
}
}
};

// 8 ranks, 1 stage, every agent succeeds -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_8_stage_1) {
constexpr size_t kRankSize = 8;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 1, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// 4 ranks, 1 stage, every agent succeeds -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_4) {
constexpr size_t kRankSize = 4;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 1, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// 32 ranks, 1 stage, every agent succeeds -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_32_stage_1) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 1, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// 32 ranks split across 2 pipeline stages, all successful -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_32_stage_2) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 2, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// 32 ranks split across 4 pipeline stages, all successful -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_32_stage_4) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// 64 ranks split across 8 pipeline stages, all successful -> Predict succeeds.
TEST_F(TestDistributedInference, test_agent_64_stage_8) {
constexpr size_t kRankSize = 64;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 8, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), SUCCESS);
}

// A null output pointer must make Predict throw; `status` keeps its
// default-constructed (failed) value since the call never returns.
TEST_F(TestDistributedInference, test_output_nullptr) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

Status status;
std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
ASSERT_ANY_THROW({ status = servable->Predict(input, nullptr); });
ASSERT_EQ(status.StatusCode(), FAILED);
}

// One agent takes 11s to reply; Predict's 10s wait expires -> FAILED.
TEST_F(TestDistributedInference, test_agent_infer_more_than_10s) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, true, true);

std::vector<AgentInferResult> result_list(kRankSize);
result_list[20].prediction_time = 11000;  // ms, past the predict timeout
InitAgentSpecMap(servable, result_list);

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), FAILED);
}

// Process-wide running flag is false (shutdown in progress) -> Predict fails.
TEST_F(TestDistributedInference, test_agent_exit) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, false, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), FAILED);
}

// Configured rank_size (32) disagrees with registered agent count (12);
// Predict must throw and `status` stays at its default (failed) value.
TEST_F(TestDistributedInference, test_rank_size_not_equal_agent_num) {
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, 32, 4, true, true);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(12));

Status status;
std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
ASSERT_ANY_THROW({ status = servable->Predict(input, &output); });
ASSERT_EQ(status.StatusCode(), FAILED);
}

// One agent reports a non-zero error code and message in its reply -> FAILED.
TEST_F(TestDistributedInference, test_agent_reply_with_error_msg) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, true, true);

std::vector<AgentInferResult> result_list(kRankSize);
result_list[10].error_msg = "failed";
result_list[10].error_code = 1;
InitAgentSpecMap(servable, result_list);

std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
auto status = servable->Predict(input, &output);
ASSERT_EQ(status.StatusCode(), FAILED);
}

// Model not loaded yet; Predict must throw and `status` stays at its
// default (failed) value.
TEST_F(TestDistributedInference, test_model_not_loaded) {
constexpr size_t kRankSize = 32;
auto servable = std::make_shared<DistributedServable>();
InitDistributedServable(servable, kRankSize, 4, true, false);
InitAgentSpecMap(servable, std::vector<AgentInferResult>(kRankSize));

Status status;
std::vector<TensorBasePtr> input;
std::vector<TensorBasePtr> output;
ASSERT_ANY_THROW({ status = servable->Predict(input, &output); });
ASSERT_EQ(status.StatusCode(), FAILED);
}

} // namespace serving
} // namespace mindspore

Loading…
Cancel
Save