Browse Source

!122 merge gpt3 to master

From: @xu-yfei
Reviewed-by: @zhoufeng54,@linqingke
Signed-off-by: @zhoufeng54
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a939bfafaf
71 changed files with 3532 additions and 1044 deletions
  1. +13
    -0
      mindspore_serving/ccsrc/common/exit_handle.cc
  2. +2
    -0
      mindspore_serving/ccsrc/common/exit_handle.h
  3. +25
    -0
      mindspore_serving/ccsrc/common/grpc_client.cc
  4. +115
    -0
      mindspore_serving/ccsrc/common/grpc_client.h
  5. +50
    -0
      mindspore_serving/ccsrc/common/proto_tensor.cc
  6. +4
    -0
      mindspore_serving/ccsrc/common/proto_tensor.h
  7. +218
    -122
      mindspore_serving/ccsrc/common/servable.cc
  8. +42
    -7
      mindspore_serving/ccsrc/common/servable.h
  9. +1
    -1
      mindspore_serving/ccsrc/master/dispacther.h
  10. +0
    -73
      mindspore_serving/ccsrc/master/grpc/grpc_client.cc
  11. +0
    -68
      mindspore_serving/ccsrc/master/grpc/grpc_client.h
  12. +1
    -2
      mindspore_serving/ccsrc/master/notify_worker/base_notify.h
  13. +2
    -3
      mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc
  14. +0
    -3
      mindspore_serving/ccsrc/master/server.cc
  15. +63
    -0
      mindspore_serving/ccsrc/python/agent/agent_py.cc
  16. +47
    -0
      mindspore_serving/ccsrc/python/agent/agent_py.h
  17. +97
    -16
      mindspore_serving/ccsrc/python/serving_py.cc
  18. +10
    -1
      mindspore_serving/ccsrc/python/worker/servable_py.cc
  19. +1
    -0
      mindspore_serving/ccsrc/python/worker/servable_py.h
  20. +80
    -6
      mindspore_serving/ccsrc/python/worker/worker_py.cc
  21. +10
    -0
      mindspore_serving/ccsrc/python/worker/worker_py.h
  22. +34
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc
  23. +48
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h
  24. +37
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc
  25. +42
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h
  26. +45
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc
  27. +48
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h
  28. +61
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/common.h
  29. +72
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc
  30. +54
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
  31. +37
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc
  32. +178
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h
  33. +335
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc
  34. +92
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h
  35. +42
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h
  36. +66
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc
  37. +48
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h
  38. +107
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
  39. +55
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
  40. +103
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc
  41. +55
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h
  42. +0
    -1
      mindspore_serving/ccsrc/worker/grpc/worker_process.cc
  43. +1
    -1
      mindspore_serving/ccsrc/worker/grpc/worker_process.h
  44. +15
    -3
      mindspore_serving/ccsrc/worker/grpc/worker_server.cc
  45. +27
    -30
      mindspore_serving/ccsrc/worker/grpc/worker_server.h
  46. +0
    -126
      mindspore_serving/ccsrc/worker/inference/inference.h
  47. +62
    -84
      mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc
  48. +15
    -23
      mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h
  49. +254
    -0
      mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc
  50. +69
    -0
      mindspore_serving/ccsrc/worker/local_servable/local_sevable.h
  51. +0
    -33
      mindspore_serving/ccsrc/worker/model.cc
  52. +6
    -20
      mindspore_serving/ccsrc/worker/sevable_base.h
  53. +11
    -11
      mindspore_serving/ccsrc/worker/work_executor.cc
  54. +3
    -5
      mindspore_serving/ccsrc/worker/work_executor.h
  55. +57
    -264
      mindspore_serving/ccsrc/worker/worker.cc
  56. +10
    -19
      mindspore_serving/ccsrc/worker/worker.h
  57. +2
    -0
      mindspore_serving/master/_master.py
  58. +43
    -0
      mindspore_serving/proto/ms_agent.proto
  59. +53
    -0
      mindspore_serving/proto/ms_distributed.proto
  60. +2
    -0
      mindspore_serving/proto/ms_service.proto
  61. +6
    -0
      mindspore_serving/proto/ms_worker.proto
  62. +3
    -1
      mindspore_serving/worker/_worker.py
  63. +250
    -0
      mindspore_serving/worker/distributed/agent_startup.py
  64. +131
    -0
      mindspore_serving/worker/distributed/distributed_worker.py
  65. +43
    -0
      mindspore_serving/worker/distributed/register.py
  66. +66
    -0
      mindspore_serving/worker/distributed/worker_agent.py
  67. +2
    -24
      mindspore_serving/worker/register/method.py
  68. +17
    -16
      mindspore_serving/worker/register/servable.py
  69. +11
    -5
      tests/ut/cpp/common/test_servable_common.h
  70. +17
    -63
      tests/ut/cpp/tests/test_start_worker.cc
  71. +16
    -13
      tests/ut/runtest.sh

+ 13
- 0
mindspore_serving/ccsrc/common/exit_handle.cc View File

@@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() {
exit_future.wait();
}

// waiting ctrl+c or stop message to exit,
// if no server is running or server has exited, there is no need to wait
void ExitSignalHandle::AgentWait() {
if (!is_running_) {
MSI_LOG_INFO << "Exit Handle has not started or has exited";
return;
}
auto exit_future = agent_exit_requested_.get_future();
exit_future.wait();
}

void ExitSignalHandle::Start() {
if (is_running_) {
return;
@@ -62,6 +73,7 @@ void ExitSignalHandle::Start() {
is_running_ = true;
master_exit_requested_ = std::promise<void>();
worker_exit_requested_ = std::promise<void>();
agent_exit_requested_ = std::promise<void>();
has_exited_.clear();
InitSignalHandle();
}
@@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() {
if (!has_exited_.test_and_set()) {
master_exit_requested_.set_value();
worker_exit_requested_.set_value();
agent_exit_requested_.set_value();
is_running_ = false;
}
}


+ 2
- 0
mindspore_serving/ccsrc/common/exit_handle.h View File

@@ -32,6 +32,7 @@ class MS_API ExitSignalHandle {
void InitSignalHandle();
void MasterWait();
void WorkerWait();
void AgentWait();
void Start();
void Stop();
bool HasStopped();
@@ -39,6 +40,7 @@ class MS_API ExitSignalHandle {
private:
std::promise<void> master_exit_requested_;
std::promise<void> worker_exit_requested_;
std::promise<void> agent_exit_requested_;
std::atomic_flag has_exited_ = true;
std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT;
std::atomic_bool is_running_ = false;


+ 25
- 0
mindspore_serving/ccsrc/common/grpc_client.cc View File

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/grpc_client.h"
namespace mindspore {
namespace serving {
std::unique_ptr<MSPredictClient> client_;
std::unique_ptr<MSDistributedClient> distributed_client_;
} // namespace serving
} // namespace mindspore

+ 115
- 0
mindspore_serving/ccsrc/common/grpc_client.h View File

@@ -0,0 +1,115 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
#define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <memory>
#include <functional>
#include <thread>
#include <string>
#include <utility>
#include "common/serving_common.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
#include "proto/ms_master.pb.h"
#include "proto/ms_master.grpc.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
namespace mindspore {
namespace serving {
using PredictOnFinish = std::function<void()>;
using DispatchCallback = std::function<void(Status status)>;
template <typename Request, typename Reply, typename MSStub>
class MSServiceClient {
public:
MSServiceClient() = default;
~MSServiceClient() {
if (in_running_) {
cq_.Shutdown();
if (client_thread_.joinable()) {
try {
client_thread_.join();
} catch (const std::system_error &) {
} catch (...) {
}
}
}
in_running_ = false;
}
void Start() {
client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this);
in_running_ = true;
}
void AsyncCompleteRpc() {
void *got_tag;
bool ok = false;
while (cq_.Next(&got_tag, &ok)) {
AsyncClientCall *call = static_cast<AsyncClientCall *>(got_tag);
if (call->status.ok()) {
call->callback(SUCCESS);
} else {
MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message();
call->callback(Status(FAILED, call->status.error_message()));
}
delete call;
}
}
void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) {
AsyncClientCall *call = new AsyncClientCall;
call->reply = reply;
call->callback = std::move(callback);
call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_);
call->response_reader->StartCall();
call->response_reader->Finish(call->reply, &call->status, call);
MSI_LOG(INFO) << "Finish send Predict";
}
private:
struct AsyncClientCall {
grpc::ClientContext context;
grpc::Status status;
Reply *reply;
DispatchCallback callback;
std::shared_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader;
};
grpc::CompletionQueue cq_;
std::thread client_thread_;
bool in_running_ = false;
};
using MSPredictClient = MSServiceClient<proto::PredictRequest, proto::PredictReply, proto::MSWorker::Stub>;
using MSDistributedClient =
MSServiceClient<proto::DistributedPredictRequest, proto::DistributedPredictReply, proto::MSAgent::Stub>;
extern std::unique_ptr<MSPredictClient> client_;
extern std::unique_ptr<MSDistributedClient> distributed_client_;
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H

+ 50
- 0
mindspore_serving/ccsrc/common/proto_tensor.cc View File

@@ -341,6 +341,56 @@ Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::Predict
return SUCCESS;
}

void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) {
worker_specs->rank_id = specs.rank_id();
worker_specs->batch_size = specs.batch_size();
for (auto &in : specs.inputs()) {
TensorInfo info;
info.data_type = ProtoTensor::TransDataType2Inference(in.dtype());
info.size = in.size();
for (auto &dim : in.shape().dims()) {
info.shape.push_back(dim);
}
worker_specs->input_infos.push_back(info);
}
for (auto &out : specs.outputs()) {
TensorInfo info;
info.data_type = ProtoTensor::TransDataType2Inference(out.dtype());
for (auto &dim : out.shape().dims()) {
info.shape.push_back(dim);
}
worker_specs->output_infos.push_back(info);
}
}

void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs,
proto::AgentRegisterRequest *request) {
for (size_t i = 0; i < worker_specs.size(); i++) {
auto &spec = worker_specs[i];
auto worker_spec = request->add_agent_spec();
worker_spec->set_rank_id(spec.rank_id);
worker_spec->set_batch_size(spec.batch_size);
for (auto &method : spec.input_infos) {
auto proto_method = worker_spec->add_inputs();
proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type));
proto_method->set_size(method.size);
auto proto_shape = proto_method->mutable_shape();
for (auto &dim : method.shape) {
proto_shape->add_dims(dim);
}
}
for (auto &method : spec.output_infos) {
auto proto_method = worker_spec->add_outputs();
proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type));
proto_method->set_size(method.size);
auto proto_shape = proto_method->mutable_shape();
for (auto &dim : method.shape) {
proto_shape->add_dims(dim);
}
}
}
}

Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) {
Status status;
ProtoTensor tensor_input(const_cast<proto::Tensor *>(&tensor));


+ 4
- 0
mindspore_serving/ccsrc/common/proto_tensor.h View File

@@ -24,6 +24,7 @@
#include "common/serving_common.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_master.pb.h"
#include "proto/ms_distributed.pb.h"
#include "common/instance.h"
#include "common/servable.h"

@@ -68,6 +69,9 @@ class MS_API GrpcTensorHelper {
std::vector<InstanceData> *results);
static Status CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector<Instance> &inputs,
proto::PredictReply *reply);
static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs);
static void CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs,
proto::AgentRegisterRequest *request);

private:
static Status CreateInstanceFromRequestInstances(const proto::PredictRequest &request,


+ 218
- 122
mindspore_serving/ccsrc/common/servable.cc View File

@@ -25,11 +25,23 @@ namespace mindspore::serving {

std::string ServableMeta::Repr() const {
std::ostringstream stream;
stream << "path(" << servable_name << ") file(" << servable_file + ")";
switch (servable_type) {
case kServableTypeUnknown:
stream << "undeclared servable, servable name: '" << common_meta.servable_name << "'";
break;
case kServableTypeLocal:
stream << "local servable, servable name: '" << common_meta.servable_name << "', file: '"
<< local_meta.servable_file + "'";
break;
case kServableTypeDistributed:
stream << "distributed servable, servable name: '" << common_meta.servable_name
<< "', rank size: " << distributed_meta.rank_size << ", stage size " << distributed_meta.stage_size;
break;
}
return stream.str();
}

void ServableMeta::SetModelFormat(const std::string &format) {
void LocalServableMeta::SetModelFormat(const std::string &format) {
if (format == "om") {
model_format = kOM;
} else if (format == "mindir") {
@@ -63,142 +75,181 @@ std::string RequestSpec::Repr() const {
return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version;
}

Status ServableSignature::Check() const {
std::set<std::string> method_set;
Status ServableSignature::CheckPreprocessInput(const MethodSignature &method, size_t *preprocess_outputs_count) const {
std::string model_str = servable_meta.Repr();
const auto &preprocess_name = method.preprocess_name;
if (!preprocess_name.empty()) {
auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name);
if (preprocess == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
<< " preprocess " << preprocess_name << " not defined";
}
*preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name);

for (auto &method : methods) {
if (method_set.count(method.method_name) > 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " " << method.method_name << " has been defined repeatly";
for (size_t i = 0; i < method.preprocess_inputs.size(); i++) {
auto &input = method.preprocess_inputs[i];
if (input.first != kPredictPhaseTag_Input) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i
<< "th input cannot not come from '" << input.first << "'";
}
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
<< "th input uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
}
method_set.emplace(method.method_name);
}
return SUCCESS;
}

size_t preprocess_outputs_count = 0;
size_t postprocess_outputs_count = 0;
Status ServableSignature::CheckPredictInput(const MethodSignature &method, size_t preprocess_outputs_count) const {
std::string model_str = servable_meta.Repr();

const auto &preprocess_name = method.preprocess_name;
if (!preprocess_name.empty()) {
auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name);
if (preprocess == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
<< " preprocess " << preprocess_name << " not defined";
for (size_t i = 0; i < method.servable_inputs.size(); i++) {
auto &input = method.servable_inputs[i];
if (input.first == kPredictPhaseTag_Input) {
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the servable " << i
<< "th input uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name);

for (size_t i = 0; i < method.preprocess_inputs.size(); i++) {
auto &input = method.preprocess_inputs[i];
if (input.first != kPredictPhaseTag_Input) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i
<< "th input cannot not come from '" << input.first << "'";
}
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the preprocess " << i
<< "th input uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
} else if (input.first == kPredictPhaseTag_Preproces) {
if (input.second >= preprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the servable " << i
<< "th input uses preprocess " << input.second
<< "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
}
} else {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of servable " << i
<< "th input cannot not come from '" << input.first << "'";
}
}
return SUCCESS;
}

Status ServableSignature::CheckPostprocessInput(const MethodSignature &method, size_t preprocess_outputs_count,
size_t *postprocess_outputs_count) const {
std::string model_str = servable_meta.Repr();
const auto &common_meta = servable_meta.common_meta;

const auto &postprocess_name = method.postprocess_name;
if (!method.postprocess_name.empty()) {
auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name);
if (postprocess == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
<< " postprocess " << postprocess_name << " not defined";
}
*postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name);

for (size_t i = 0; i < method.servable_inputs.size(); i++) {
auto &input = method.servable_inputs[i];
for (size_t i = 0; i < method.postprocess_inputs.size(); i++) {
auto &input = method.postprocess_inputs[i];
if (input.first == kPredictPhaseTag_Input) {
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the servable " << i
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
} else if (input.first == kPredictPhaseTag_Preproces) {
if (input.second >= preprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the servable " << i
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses preprocess " << input.second
<< "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
}
} else if (input.first == kPredictPhaseTag_Predict) {
if (input.second >= common_meta.outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses servable " << input.second
<< "th output, that is greater than the servable outputs size " << common_meta.outputs_count;
}
} else {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of servable " << i
<< "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i
<< "th input cannot not come from '" << input.first << "'";
}
}
}
return SUCCESS;
}

const auto &postprocess_name = method.postprocess_name;
if (!method.postprocess_name.empty()) {
auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name);
if (postprocess == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name
<< " postprocess " << postprocess_name << " not defined";
}
postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name);
Status ServableSignature::CheckReturn(const MethodSignature &method, size_t preprocess_outputs_count,
size_t postprocess_outputs_count) const {
std::string model_str = servable_meta.Repr();
const auto &common_meta = servable_meta.common_meta;

for (size_t i = 0; i < method.postprocess_inputs.size(); i++) {
auto &input = method.postprocess_inputs[i];
if (input.first == kPredictPhaseTag_Input) {
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses method " << input.second
<< "th input, that is greater than the method inputs size " << method.inputs.size();
}
} else if (input.first == kPredictPhaseTag_Preproces) {
if (input.second >= preprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses preprocess " << input.second
<< "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
}
} else if (input.first == kPredictPhaseTag_Predict) {
if (input.second >= servable_meta.outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the postprocess " << i
<< "th input uses servable " << input.second
<< "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
}
} else {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i
<< "th input cannot not come from '" << input.first << "'";
}
for (size_t i = 0; i < method.returns.size(); i++) {
auto &input = method.returns[i];
if (input.first == kPredictPhaseTag_Input) {
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
}
for (size_t i = 0; i < method.returns.size(); i++) {
auto &input = method.returns[i];
if (input.first == kPredictPhaseTag_Input) {
if (input.second >= method.inputs.size()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses method " << input.second << "th input, that is greater than the method inputs size "
<< method.inputs.size();
}
} else if (input.first == kPredictPhaseTag_Preproces) {
if (input.second >= preprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses preprocess " << input.second
<< "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
}
} else if (input.first == kPredictPhaseTag_Predict) {
if (input.second >= servable_meta.outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses servable " << input.second
<< "th output, that is greater than the servable outputs size " << servable_meta.outputs_count;
}
} else if (input.first == kPredictPhaseTag_Postprocess) {
if (input.second >= postprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses postprocess " << input.second
<< "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count;
}
} else {
} else if (input.first == kPredictPhaseTag_Preproces) {
if (input.second >= preprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses preprocess " << input.second
<< "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count;
}
} else if (input.first == kPredictPhaseTag_Predict) {
if (input.second >= common_meta.outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses servable " << input.second
<< "th output, that is greater than the servable outputs size " << common_meta.outputs_count;
}
} else if (input.first == kPredictPhaseTag_Postprocess) {
if (input.second >= postprocess_outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of method " << i
<< "th output cannot not come from '" << input.first << "'";
<< "Model " << model_str << " method " << method.method_name << ", the method " << i
<< "th output uses postprocess " << input.second
<< "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count;
}
} else {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << model_str << " method " << method.method_name << ", the data of method " << i
<< "th output cannot not come from '" << input.first << "'";
}
}
return SUCCESS;
}

Status ServableSignature::Check() const {
std::set<std::string> method_set;
Status status;
for (auto &method : methods) {
if (method_set.count(method.method_name) > 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model " << servable_meta.Repr() << " " << method.method_name << " has been defined repeatedly";
}
method_set.emplace(method.method_name);

size_t preprocess_outputs_count = 0;
size_t postprocess_outputs_count = 0;
status = CheckPreprocessInput(method, &preprocess_outputs_count);
if (status != SUCCESS) {
return status;
}
status = CheckPredictInput(method, preprocess_outputs_count);
if (status != SUCCESS) {
return status;
}
status = CheckPostprocessInput(method, preprocess_outputs_count, &postprocess_outputs_count);
if (status != SUCCESS) {
return status;
}
status = CheckReturn(method, preprocess_outputs_count, postprocess_outputs_count);
if (status != SUCCESS) {
return status;
}
}
return SUCCESS;
@@ -216,7 +267,7 @@ bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodS
}

void ServableStorage::Register(const ServableSignature &def) {
auto model_name = def.servable_meta.servable_name;
auto model_name = def.servable_meta.common_meta.servable_name;
if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) {
MSI_LOG_WARNING << "Servable " << model_name << " has already been defined";
}
@@ -258,16 +309,60 @@ Status ServableStorage::RegisterMethod(const MethodSignature &method) {
return SUCCESS;
}

void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) {
MSI_LOG_INFO << "Declare servable " << servable.servable_name;
auto it = servable_signatures_map_.find(servable.servable_name);
Status ServableStorage::DeclareServable(ServableMeta servable) {
auto &common_meta = servable.common_meta;
MSI_LOG_INFO << "Declare servable " << common_meta.servable_name;
servable.servable_type = kServableTypeLocal;
if (servable.local_meta.servable_file.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty";
}
if (servable.local_meta.model_format == ModelType::kUnknownType) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Declare servable " << common_meta.servable_name << " failed, model_format is not inited";
}
auto it = servable_signatures_map_.find(common_meta.servable_name);
if (it == servable_signatures_map_.end()) {
ServableSignature signature;
signature.servable_meta = servable;
servable_signatures_map_[servable.servable_name] = signature;
return;
servable_signatures_map_[common_meta.servable_name] = signature;
return SUCCESS;
}
it->second.servable_meta = servable;
auto &org_servable_meta = it->second.servable_meta;
if (org_servable_meta.servable_type != kServableTypeUnknown) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr();
}
org_servable_meta = servable;
return SUCCESS;
}

Status ServableStorage::DeclareDistributedServable(ServableMeta servable) {
auto &common_meta = servable.common_meta;
MSI_LOG_INFO << "Declare servable " << common_meta.servable_name;
servable.servable_type = kServableTypeDistributed;
if (servable.distributed_meta.rank_size == 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Declare distributed servable " << common_meta.servable_name << " failed, rank_size cannot be 0";
}
if (servable.distributed_meta.stage_size == 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Declare distributed servable " << common_meta.servable_name << " failed, stage_size cannot be 0";
}
auto it = servable_signatures_map_.find(common_meta.servable_name);
if (it == servable_signatures_map_.end()) {
ServableSignature signature;
signature.servable_meta = servable;
servable_signatures_map_[common_meta.servable_name] = signature;
return SUCCESS;
}
auto &org_servable_meta = it->second.servable_meta;
if (org_servable_meta.servable_type != kServableTypeUnknown) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr();
}
org_servable_meta = servable;
return SUCCESS;
}

Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count,
@@ -277,18 +372,19 @@ Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name
return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name;
}
auto &servable_meta = it->second.servable_meta;
if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) {
auto &common_meta = servable_meta.common_meta;
if (common_meta.inputs_count != 0 && common_meta.inputs_count != inputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count "
<< servable_meta.inputs_count << ",servable name " << servable_name;
<< common_meta.inputs_count << ",servable name " << servable_name;
}
if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) {
if (common_meta.outputs_count != 0 && common_meta.outputs_count != outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count "
<< servable_meta.outputs_count << ",servable name " << servable_name;
<< common_meta.outputs_count << ",servable name " << servable_name;
}
servable_meta.inputs_count = inputs_count;
servable_meta.outputs_count = outputs_count;
common_meta.inputs_count = inputs_count;
common_meta.outputs_count = outputs_count;
return SUCCESS;
}

@@ -298,8 +394,8 @@ std::vector<size_t> ServableStorage::GetInputOutputInfo(const std::string &serva
if (it == servable_signatures_map_.end()) {
return result;
}
result.push_back(it->second.servable_meta.inputs_count);
result.push_back(it->second.servable_meta.outputs_count);
result.push_back(it->second.servable_meta.common_meta.inputs_count);
result.push_back(it->second.servable_meta.common_meta.outputs_count);
return result;
}



+ 42
- 7
mindspore_serving/ccsrc/common/servable.h View File

@@ -81,19 +81,39 @@ struct RequestSpec {
std::string Repr() const;
};

struct MS_API ServableMeta {
enum ServableType {
kServableTypeUnknown = 0,
kServableTypeLocal = 1,
kServableTypeDistributed = 2,
};

struct CommonServableMeta {
std::string servable_name;
std::string servable_file; // file name
ModelType model_format; // OM, MindIR
bool with_batch_dim = true; // whether there is batch dim in model's inputs/outputs
std::vector<int> without_batch_dim_inputs;
size_t inputs_count = 0;
size_t outputs_count = 0;
};

std::map<std::string, std::string> load_options; // Acl options
std::vector<int> without_batch_dim_inputs;
struct MS_API LocalServableMeta {
std::string servable_file; // file name
ModelType model_format = ModelType::kUnknownType; // OM, MindIR
std::map<std::string, std::string> load_options; // Acl options
void SetModelFormat(const std::string &format);
};

struct DistributedServableMeta {
size_t rank_size = 0;
size_t stage_size = 0;
};

struct MS_API ServableMeta {
ServableType servable_type = kServableTypeUnknown;
CommonServableMeta common_meta;
LocalServableMeta local_meta;
DistributedServableMeta distributed_meta;

std::string Repr() const;
void SetModelFormat(const std::string &format);
};

struct ServableSignature {
@@ -102,6 +122,12 @@ struct ServableSignature {

Status Check() const;
bool GetMethodDeclare(const std::string &method_name, MethodSignature *method);

private:
Status CheckPreprocessInput(const MethodSignature &method, size_t *pre) const;
Status CheckPredictInput(const MethodSignature &method, size_t pre) const;
Status CheckPostprocessInput(const MethodSignature &method, size_t pre, size_t *post) const;
Status CheckReturn(const MethodSignature &method, size_t pre, size_t post) const;
};

class MS_API ServableStorage {
@@ -111,7 +137,8 @@ class MS_API ServableStorage {

bool GetServableDef(const std::string &model_name, ServableSignature *def) const;

void DeclareServable(const ServableMeta &servable);
Status DeclareServable(ServableMeta servable);
Status DeclareDistributedServable(ServableMeta servable);

Status RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count);
std::vector<size_t> GetInputOutputInfo(const std::string &servable_name) const;
@@ -144,6 +171,14 @@ static inline LogStream &operator<<(LogStream &stream, PredictPhaseTag data_type
return stream;
}

struct WorkerAgentSpec {
std::string agent_address;
uint32_t rank_id = 0;
std::vector<TensorInfo> input_infos;
std::vector<TensorInfo> output_infos;
uint32_t batch_size = 0;
};

} // namespace mindspore::serving

#endif // MINDSPORE_SERVING_SERVABLE_H

+ 1
- 1
mindspore_serving/ccsrc/master/dispacther.h View File

@@ -27,7 +27,7 @@
#include "common/instance.h"
#include "common/servable.h"
#include "master/notify_worker/base_notify.h"
#include "master/grpc/grpc_client.h"
#include "common/grpc_client.h"

namespace mindspore::serving {



+ 0
- 73
mindspore_serving/ccsrc/master/grpc/grpc_client.cc View File

@@ -1,73 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "master/grpc/grpc_client.h"
#include <string>
#include <utility>
#include "master/grpc/grpc_server.h"
namespace mindspore {
namespace serving {
std::unique_ptr<MSServiceClient> client_;
MSServiceClient::~MSServiceClient() {
if (in_running_) {
cq_.Shutdown();
if (client_thread_.joinable()) {
try {
client_thread_.join();
} catch (const std::system_error &) {
} catch (...) {
}
}
}
in_running_ = false;
}
void MSServiceClient::PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
std::shared_ptr<proto::MSWorker::Stub> stub, DispatchCallback callback) {
AsyncClientCall *call = new AsyncClientCall;
call->reply = reply;
call->callback = std::move(callback);
call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_);
call->response_reader->StartCall();
call->response_reader->Finish(call->reply, &call->status, call);
MSI_LOG(INFO) << "Finish send Predict";
}
void MSServiceClient::AsyncCompleteRpc() {
void *got_tag;
bool ok = false;
while (cq_.Next(&got_tag, &ok)) {
AsyncClientCall *call = static_cast<AsyncClientCall *>(got_tag);
if (call->status.ok()) {
call->callback(SUCCESS);
} else {
MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message();
call->callback(Status(FAILED, call->status.error_message()));
}
delete call;
}
}
void MSServiceClient::Start() {
client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this);
in_running_ = true;
}
} // namespace serving
} // namespace mindspore

+ 0
- 68
mindspore_serving/ccsrc/master/grpc/grpc_client.h View File

@@ -1,68 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
#define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <memory>
#include <functional>
#include <thread>
#include "common/serving_common.h"
#include "master/notify_worker/base_notify.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
#include "proto/ms_master.pb.h"
#include "proto/ms_master.grpc.pb.h"
#include "proto/ms_worker.grpc.pb.h"
namespace mindspore {
namespace serving {
class MSServiceClient;
extern std::unique_ptr<MSServiceClient> client_;
using PredictOnFinish = std::function<void()>;
class MSServiceClient {
public:
MSServiceClient() = default;
~MSServiceClient();
void AsyncCompleteRpc();
void Start();
void PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
std::shared_ptr<proto::MSWorker::Stub> stub, DispatchCallback callback);
private:
struct AsyncClientCall {
grpc::ClientContext context;
grpc::Status status;
proto::PredictReply *reply;
DispatchCallback callback;
std::shared_ptr<grpc::ClientAsyncResponseReader<proto::PredictReply>> response_reader;
};
grpc::CompletionQueue cq_;
std::thread client_thread_;
bool in_running_ = false;
};
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H

+ 1
- 2
mindspore_serving/ccsrc/master/notify_worker/base_notify.h View File

@@ -22,12 +22,11 @@
#include "common/serving_common.h"
#include "common/servable.h"
#include "proto/ms_service.pb.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {

using DispatchCallback = std::function<void(Status status)>;

class MS_API BaseNotifyWorker {
public:
BaseNotifyWorker() = default;


+ 2
- 3
mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc View File

@@ -20,7 +20,6 @@
#include <thread>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "master/grpc/grpc_client.h"

namespace mindspore {
namespace serving {
@@ -56,10 +55,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro
<< worker_address_;
}
if (!client_) {
client_ = std::make_unique<MSServiceClient>();
client_ = std::make_unique<MSPredictClient>();
client_->Start();
}
client_->PredictAsync(request, reply, stub_, callback);
client_->PredictAsync(request, reply, stub_.get(), callback);
return SUCCESS;
}



+ 0
- 3
mindspore_serving/ccsrc/master/server.cc View File

@@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma
if (grpc_async_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running";
}
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
if (max_msg_mb_size > gRpcMaxMBMsgSize) {
MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size
<< "MB to 512MB";
@@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma
}

Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize,
"Master");
}

Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size,
int time_out_second) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second);
}



+ 63
- 0
mindspore_serving/ccsrc/python/agent/agent_py.cc View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "python/agent/agent_py.h"
#include "common/exit_handle.h"
#include "worker/distributed_worker/agent_startup.h"
#include "worker/distributed_worker/worker_agent.h"

namespace mindspore::serving {

DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) {
auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}

DistributedServableConfig config;
status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
return config;
}

void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port);
}

void PyAgent::StartAgent(const AgentStartUpConfig &start_config) {
auto status = WorkerAgent::Instance().StartAgent(start_config);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}

void PyAgent::WaitAndClear() {
{
py::gil_scoped_release release;
ExitSignalHandle::Instance().AgentWait();
}
WorkerAgent::Instance().Clear();
MSI_LOG_INFO << "Python agent end wait and clear";
}

void PyAgent::StopAndClear() {
ExitSignalHandle::Instance().Stop();
WorkerAgent::Instance().Clear();
}

} // namespace mindspore::serving

+ 47
- 0
mindspore_serving/ccsrc/python/agent/agent_py.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVER_AGENT_PY_H
#define MINDSPORE_SERVER_AGENT_PY_H

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <string>
#include <memory>
#include "common/serving_common.h"
#include "worker/distributed_worker/common.h"

namespace py = pybind11;

namespace mindspore {
namespace serving {

class MS_API PyAgent {
public:
static void StartAgent(const AgentStartUpConfig &start_config);

static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
static void WaitAndClear();
static void StopAndClear();
// from start up, not agent
static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port);
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVER_AGENT_PY_H

+ 97
- 16
mindspore_serving/ccsrc/python/serving_py.cc View File

@@ -23,10 +23,14 @@
#include "common/servable.h"
#include "worker/context.h"
#include "python/master/master_py.h"
#include "python/agent/agent_py.h"
#include "common/exit_handle.h"
#include "worker/distributed_worker/worker_agent.h"

namespace mindspore::serving {

PYBIND11_MODULE(_mindspore_serving, m) {
void PyRegServable(pybind11::module *m_ptr) {
auto &m = *m_ptr;
// avoid as numpy object memory copy in PyTensor::AsPythonData
py::class_<TensorBase, TensorBasePtr>(m, "Tensor_");

@@ -68,16 +72,30 @@ PYBIND11_MODULE(_mindspore_serving, m) {
.def_readwrite("version_number", &RequestSpec::version_number)
.def_readwrite("method_name", &RequestSpec::method_name);

py::class_<CommonServableMeta>(m, "CommonServableMeta_")
.def(py::init<>())
.def_readwrite("servable_name", &CommonServableMeta::servable_name)
.def_readwrite("inputs_count", &CommonServableMeta::inputs_count)
.def_readwrite("outputs_count", &CommonServableMeta::outputs_count)
.def_readwrite("with_batch_dim", &CommonServableMeta::with_batch_dim)
.def_readwrite("without_batch_dim_inputs", &CommonServableMeta::without_batch_dim_inputs);

py::class_<LocalServableMeta>(m, "LocalServableMeta_")
.def(py::init<>())
.def_readwrite("servable_file", &LocalServableMeta::servable_file)
.def_readwrite("options", &LocalServableMeta::load_options)
.def("set_model_format", &LocalServableMeta::SetModelFormat);

py::class_<DistributedServableMeta>(m, "DistributedServableMeta_")
.def(py::init<>())
.def_readwrite("rank_size", &DistributedServableMeta::rank_size)
.def_readwrite("stage_size", &DistributedServableMeta::stage_size);

py::class_<ServableMeta>(m, "ServableMeta_")
.def(py::init<>())
.def_readwrite("servable_name", &ServableMeta::servable_name)
.def_readwrite("inputs_count", &ServableMeta::inputs_count)
.def_readwrite("outputs_count", &ServableMeta::outputs_count)
.def_readwrite("servable_file", &ServableMeta::servable_file)
.def_readwrite("with_batch_dim", &ServableMeta::with_batch_dim)
.def_readwrite("options", &ServableMeta::load_options)
.def_readwrite("without_batch_dim_inputs", &ServableMeta::without_batch_dim_inputs)
.def("set_model_format", &ServableMeta::SetModelFormat);
.def_readwrite("common_meta", &ServableMeta::common_meta)
.def_readwrite("local_meta", &ServableMeta::local_meta)
.def_readwrite("distributed_meta", &ServableMeta::distributed_meta);

py::class_<ServableSignature>(m, "ServableSignature_")
.def(py::init<>())
@@ -87,8 +105,34 @@ PYBIND11_MODULE(_mindspore_serving, m) {
py::class_<PyServableStorage>(m, "ServableStorage_")
.def_static("register_servable_input_output_info", &PyServableStorage::RegisterInputOutputInfo)
.def_static("register_method", &PyServableStorage::RegisterMethod)
.def_static("declare_servable", &PyServableStorage::DeclareServable);
.def_static("declare_servable", &PyServableStorage::DeclareServable)
.def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable);

py::class_<OneRankConfig>(m, "OneRankConfig_")
.def(py::init<>())
.def_readwrite("device_id", &OneRankConfig::device_id)
.def_readwrite("ip", &OneRankConfig::ip);

py::class_<DistributedServableConfig>(m, "DistributedServableConfig_")
.def(py::init<>())
.def_readwrite("common_meta", &DistributedServableConfig::common_meta)
.def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta)
.def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content)
.def_readwrite("rank_list", &DistributedServableConfig::rank_list);
}

void PyRegMaster(pybind11::module *m_ptr) {
auto &m = *m_ptr;
py::class_<PyMaster>(m, "Master_")
.def_static("start_grpc_server", &PyMaster::StartGrpcServer)
.def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer)
.def_static("start_restful_server", &PyMaster::StartRestfulServer)
.def_static("wait_and_clear", &PyMaster::WaitAndClear)
.def_static("stop_and_clear", &PyMaster::StopAndClear);
}

void PyRegWorker(pybind11::module *m_ptr) {
auto &m = *m_ptr;
py::class_<TaskContext>(m, "TaskContext_").def(py::init<>());

py::class_<TaskItem>(m, "TaskItem_")
@@ -108,6 +152,8 @@ PYBIND11_MODULE(_mindspore_serving, m) {
py::class_<PyWorker>(m, "Worker_")
.def_static("start_servable", &PyWorker::StartServable)
.def_static("start_servable_in_master", &PyWorker::StartServableInMaster)
.def_static("start_distributed_servable", &PyWorker::StartDistributedServable)
.def_static("start_distributed_servable_in_master", &PyWorker::StartDistributedServableInMaster)
.def_static("get_batch_size", &PyWorker::GetBatchSize)
.def_static("wait_and_clear", &PyWorker::WaitAndClear)
.def_static("stop_and_clear", PyWorker::StopAndClear)
@@ -130,17 +176,52 @@ PYBIND11_MODULE(_mindspore_serving, m) {
}
})
.def("set_device_id", &ServableContext::SetDeviceId);
}

py::class_<PyMaster, std::shared_ptr<PyMaster>>(m, "Master_")
.def_static("start_grpc_server", &PyMaster::StartGrpcServer)
.def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer)
.def_static("start_restful_server", &PyMaster::StartRestfulServer)
.def_static("wait_and_clear", &PyMaster::WaitAndClear)
.def_static("stop_and_clear", &PyMaster::StopAndClear);
void PyRegWorkerAgent(pybind11::module *m_ptr) {
auto &m = *m_ptr;
py::class_<PyAgent>(m, "WorkerAgent_")
.def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker)
.def_static("wait_and_clear", &PyAgent::WaitAndClear)
.def_static("stop_and_clear", &PyAgent::StopAndClear)
.def_static("notify_failed", &PyAgent::NotifyFailed)
.def_static("start_agent", &PyAgent::StartAgent);

py::class_<AgentStartUpConfig>(m, "AgentStartUpConfig_")
.def(py::init<>())
.def_readwrite("rank_id", &AgentStartUpConfig::rank_id)
.def_readwrite("device_id", &AgentStartUpConfig::device_id)
.def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name)
.def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name)
.def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name)
.def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip)
.def_readwrite("agent_port", &AgentStartUpConfig::agent_port)
.def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip)
.def_readwrite("worker_port", &AgentStartUpConfig::worker_port)
.def_readwrite("common_meta", &AgentStartUpConfig::common_meta);
}

class PyExitSignalHandle {
public:
static void Start() { ExitSignalHandle::Instance().Start(); }
static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); }
};

// cppcheck-suppress syntaxError
PYBIND11_MODULE(_mindspore_serving, m) {
PyRegServable(&m);
PyRegMaster(&m);
PyRegWorker(&m);
PyRegWorkerAgent(&m);

py::class_<PyExitSignalHandle>(m, "ExitSignalHandle_")
.def_static("start", &PyExitSignalHandle::Start)
.def_static("has_stopped", &PyExitSignalHandle::HasStopped);

(void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
Server::Instance().Clear();
Worker::GetInstance().Clear();
WorkerAgent::Instance().Clear();
}});
}



+ 10
- 1
mindspore_serving/ccsrc/python/worker/servable_py.cc View File

@@ -25,7 +25,16 @@ void PyServableStorage::RegisterMethod(const MethodSignature &method) {
}
}
void PyServableStorage::DeclareServable(const ServableMeta &servable) {
ServableStorage::Instance().DeclareServable(servable);
auto status = ServableStorage::Instance().DeclareServable(servable);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}
void PyServableStorage::DeclareDistributedServable(const ServableMeta &servable) {
auto status = ServableStorage::Instance().DeclareDistributedServable(servable);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}
void PyServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count,
size_t outputs_count) {


+ 1
- 0
mindspore_serving/ccsrc/python/worker/servable_py.h View File

@@ -27,6 +27,7 @@ class MS_API PyServableStorage {
static void RegisterMethod(const MethodSignature &method);

static void DeclareServable(const ServableMeta &servable);
static void DeclareDistributedServable(const ServableMeta &servable);

static void RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count);
static void Clear();


+ 80
- 6
mindspore_serving/ccsrc/python/worker/worker_py.cc View File

@@ -21,21 +21,33 @@
#include "common/exit_handle.h"
#include "worker/notfiy_master/grpc_notify.h"
#include "worker/notfiy_master/local_notify.h"
#include "worker/local_servable/local_sevable.h"
#include "worker/distributed_worker/distributed_servable.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/distributed_process/distributed_server.h"

namespace mindspore::serving {

void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number,
const std::string &master_ip, uint32_t master_port, const std::string &host_ip,
uint32_t host_port) {
auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, host_ip, host_port);
auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master);
const std::string &master_ip, uint32_t master_port, const std::string &worker_ip,
uint32_t worker_port) {
auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port);
auto servable = std::make_shared<LocalModelServable>();
auto status = servable->StartServable(model_directory, model_name, version_number);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartGrpcServer(host_ip, host_port);
status = Worker::GetInstance().StartServable(servable, notify_master);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
// start grpc server
auto grpc_sever = std::make_shared<MSWorkerServer>();
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}

status = Worker::GetInstance().StartVersionController();
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
@@ -45,7 +57,69 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri
void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name,
uint32_t version_number) {
auto notify_master = std::make_shared<LocalNotifyMaster>();
auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master);
auto servable = std::make_shared<LocalModelServable>();
auto status = servable->StartServable(model_directory, model_name, version_number);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartServable(servable, notify_master);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartVersionController();
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}

void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port,
const std::string &master_ip, uint32_t master_port,
uint32_t wait_agents_time_in_seconds) {
Status status;
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable);
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}

auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port);
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number,
wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartServable(servable, notify_master);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartVersionController();
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}

void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port,
uint32_t wait_agents_time_in_seconds) {
Status status;
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable);
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}

auto notify_master = std::make_shared<LocalNotifyMaster>();
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number,
wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartServable(servable, notify_master);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}


+ 10
- 0
mindspore_serving/ccsrc/python/worker/worker_py.h View File

@@ -34,6 +34,16 @@ class MS_API PyWorker {
static void StartServableInMaster(const std::string &model_directory, const std::string &model_name,
uint32_t version_number);

static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip,
uint32_t master_port, uint32_t wait_agents_time_in_seconds);

static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port,
uint32_t wait_agents_time_in_seconds);

static int GetBatchSize();
static void WaitAndClear();
static void StopAndClear();


+ 34
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/agent_executor.h"

namespace mindspore {
namespace serving {

Status WorkerAgentExecutor::LoadModelFromFile(const AgentStartUpConfig &config) { return Status(); }
Status WorkerAgentExecutor::UnloadModel() { return Status(); }
Status WorkerAgentExecutor::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) {
return Status();
}
std::vector<serving::TensorInfo> WorkerAgentExecutor::GetInputInfos() const {
return std::vector<serving::TensorInfo>();
}
std::vector<serving::TensorInfo> WorkerAgentExecutor::GetOutputInfos() const {
return std::vector<serving::TensorInfo>();
}
ssize_t WorkerAgentExecutor::GetBatchSize() const { return 0; }
} // namespace serving
} // namespace mindspore

+ 48
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H
#define MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H

#include <vector>
#include "common/serving_common.h"
#include "worker/inference/inference.h"
#include "worker/distributed_worker/common.h"

namespace mindspore {
namespace serving {
class MS_API WorkerAgentExecutor {
public:
// from python
Status LoadModelFromFile(const AgentStartUpConfig &config);
// ctrl+c, worker exit
Status UnloadModel();

// from worker
Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply);

// for register
std::vector<serving::TensorInfo> GetInputInfos() const;

std::vector<serving::TensorInfo> GetOutputInfos() const;

ssize_t GetBatchSize() const;
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H

+ 37
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/agent_process/agent_process.h"
#include "worker/distributed_worker/worker_agent.h"
namespace mindspore {
namespace serving {
grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
proto::DistributedExitReply *reply) {
MSI_LOG(INFO) << "Distributed Worker Exit";
WorkerAgent::Instance().StopAgent(false);
return grpc::Status::OK;
}
grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
proto::DistributedPredictReply *reply) {
MSI_LOG(INFO) << "Begin call service Eval";
WorkerAgent::Instance().Run(*request, reply);
return grpc::Status::OK;
}
} // namespace serving
} // namespace mindspore

+ 42
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H
#define MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include "common/serving_common.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
namespace mindspore {
namespace serving {
// Service Implement
class MSAgentImpl final : public proto::MSAgent::Service {
public:
grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
proto::DistributedPredictReply *reply) override;
grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
proto::DistributedExitReply *reply) override;
};
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H

+ 45
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/agent_startup.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"

namespace mindspore {
namespace serving {

WorkerAgentStartUp &WorkerAgentStartUp::Instance() {
static WorkerAgentStartUp instance;
return instance;
}

Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) {
return Status();
}

Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) {
MSI_EXCEPTION_IF_NULL(config);
if (config_.rank_list.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready";
}
*config = config_;
return SUCCESS;
}

Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port);
}

} // namespace serving
} // namespace mindspore

+ 48
- 0
mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H
#define MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H
#include <vector>
#include <string>
#include "common/serving_common.h"
#include "worker/distributed_worker/common.h"
#include "worker/inference/inference.h"

namespace mindspore {
namespace serving {

class MS_API WorkerAgentStartUp {
public:
static WorkerAgentStartUp &Instance();
// from python, worker_agent.py
// start_worker_agent
// step1, get agents config from worker
Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
// step2, invoke from python
Status GetDistributedServableConfig(DistributedServableConfig *config);

Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);

private:
DistributedServableConfig config_;
std::string worker_address_;
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H

+ 61
- 0
mindspore_serving/ccsrc/worker/distributed_worker/common.h View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H
#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H

#include <vector>
#include <string>
#include <map>
#include "common/serving_common.h"
#include "worker/inference/inference.h"
#include "common/servable.h"

namespace mindspore {
namespace serving {

struct OneRankConfig {
std::string ip;
uint32_t device_id = 0;
};

struct DistributedServableConfig {
std::string rank_table_content;
std::vector<OneRankConfig> rank_list;

CommonServableMeta common_meta;
DistributedServableMeta distributed_meta;
};

struct AgentStartUpConfig {
uint32_t rank_id;
uint32_t device_id;
std::string model_file_name;
std::string group_file_name;
std::string rank_table_json_file_name;

std::string agent_ip;
uint32_t agent_port;
std::string worker_ip;
uint32_t worker_port;

CommonServableMeta common_meta;
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H

+ 72
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc View File

@@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include "worker/worker.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {

grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
proto::AgentRegisterReply *reply) {
MSI_EXCEPTION_IF_NULL(request);
MSI_EXCEPTION_IF_NULL(reply);
for (auto &spec : request->agent_spec()) {
WorkerAgentSpec agent_spec;
agent_spec.agent_address = request->address();
GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec);
Status status(FAILED);
status = servable_->RegisterAgent(agent_spec);
if (status != SUCCESS) {
MSI_LOG(ERROR) << "Agent Register FAILED";
}
}
return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
proto::AgentExitReply *reply) {
MSI_EXCEPTION_IF_NULL(request);
MSI_EXCEPTION_IF_NULL(reply);
for (auto &spec : request->agent_spec()) {
WorkerAgentSpec agent_spec;
agent_spec.agent_address = request->address();
GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec);
Status status(FAILED);
status = servable_->UnregisterAgent(agent_spec);
if (status != SUCCESS) {
MSI_LOG(ERROR) << "Agent Exit FAILED";
}
}
if (Worker::GetInstance().IsRunning()) {
Worker::GetInstance().StopServable();
}
return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
proto::AgentFailedReply *reply) {
if (Worker::GetInstance().IsRunning()) {
MSI_LOG_ERROR << "Expect worker should not be running";
Worker::GetInstance().StopServable();
} else {
servable_->OnAgentFailed();
}
return grpc::Status::OK;
}
} // namespace serving
} // namespace mindspore

+ 54
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <memory>
#include "common/serving_common.h"
#include "proto/ms_service.pb.h"
#include "proto/ms_service.grpc.pb.h"
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "worker/distributed_worker/distributed_servable.h"
#include "worker/grpc/worker_process.h"
namespace mindspore {
namespace serving {
// Service Implement
class MSDistributedImpl final : public MSWorkerImpl {
public:
explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
~MSDistributedImpl() = default;
grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
proto::AgentRegisterReply *reply) override;
grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
proto::AgentExitReply *reply) override;
grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
proto::AgentFailedReply *reply) override;
private:
std::shared_ptr<DistributedServable> servable_;
};
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H

+ 37
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/distributed_process/distributed_server.h"
#include <string>
#include <memory>
#include <utility>
#include "common/grpc_server.h"
namespace mindspore {
namespace serving {
Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
if (in_running_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
}
auto impl = std::make_unique<MSDistributedImpl>(servable_);
async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get());
service_impl_ = std::move(impl);
return Init();
}
} // namespace serving
} // namespace mindspore

+ 178
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h View File

@@ -0,0 +1,178 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <memory>
#include <string>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/distributed_process/distributed_process.h"
namespace mindspore {
namespace serving {
// Service Implement
class MS_API MSDistributedWorkerServer : public MSWorkerServer {
public:
explicit MSDistributedWorkerServer(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
~MSDistributedWorkerServer() = default;
Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override;
private:
std::shared_ptr<DistributedServable> servable_;
};
class DistributedServiceContext : public WorkerServiceContext {
public:
DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {}
protected:
MSDistributedImpl *dist_service_impl_ = nullptr;
};
// Service Implement
class WorkerAgentRegisterContext : public DistributedServiceContext {
public:
WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentRegisterContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
grpc::ServerAsyncResponseWriter<proto::AgentRegisterReply> responder_;
proto::AgentRegisterRequest request_;
proto::AgentRegisterReply response_;
};
class WorkerAgentExitContext : public DistributedServiceContext {
public:
WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentExitContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentExitContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
grpc::ServerAsyncResponseWriter<proto::AgentExitReply> responder_;
proto::AgentExitRequest request_;
proto::AgentExitReply response_;
};
class WorkerAgentFailedContext : public DistributedServiceContext {
public:
WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentFailedContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentFailedContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
grpc::ServerAsyncResponseWriter<proto::AgentFailedReply> responder_;
proto::AgentFailedRequest request_;
proto::AgentFailedReply response_;
};
class DistributedWorkerGrpcServer : public WorkerGrpcServer {
public:
DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl)
: WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {}
~DistributedWorkerGrpcServer() = default;
Status EnqueueRequest() {
WorkerGrpcServer::EnqueueRequest();
WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
return SUCCESS;
}
private:
MSDistributedImpl *distributed_service_impl_;
};
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H

+ 335
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc View File

@@ -0,0 +1,335 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "worker/distributed_worker/distributed_servable.h"
#include <vector>
#include <string>
#include <set>
#include "worker/distributed_worker/notify_agent/notify_agent.h"
#include "common/exit_handle.h"

namespace mindspore {
namespace serving {

DistributedServable::~DistributedServable() { Clear(); }

std::string DistributedServable::GetServableName() const { return servable_name_; }

uint64_t DistributedServable::GetServableVersion() const { return version_number_; }

Status DistributedServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return Status();
}
std::vector<TensorInfo> DistributedServable::GetInputInfos() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return input_infos_;
}

std::vector<TensorInfo> DistributedServable::GetOutputInfos() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return output_infos_;
}

uint64_t DistributedServable::GetBatchSize() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return batch_size_;
}

Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) const {
*config = config_;
return SUCCESS;
}

void DistributedServable::SetWaitAgentsPromise(bool flag) {
if (!promise_set_flag_.test_and_set()) {
agents_promise_.set_value(flag);
}
}

Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) {
std::unique_lock<std::mutex> lock{mutex_};

if (agent_spec.rank_id < config_.distributed_meta.rank_size) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size;
}
DistributedAgentContext context;
auto it = agent_spec_map_.find(agent_spec.rank_id);
if (it != agent_spec_map_.end()) {
MSI_LOG_WARNING << "rank_id " << agent_spec.rank_id << " has been registered";
return SUCCESS;
}
context.agent_spec_ = agent_spec;
std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotfiyAgent>(agent_spec.agent_address);
context.notify_agent_ = notify_agent;
agent_spec_map_[agent_spec.rank_id] = context;

if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) {
SetWaitAgentsPromise(true);
}
return SUCCESS;
}

void DistributedServable::Clear() {
std::unique_lock<std::mutex> lock{mutex_};
for (auto &agent : agent_spec_map_) {
agent.second.notify_agent_->Exit();
}
agent_spec_map_.clear();
MSI_LOG_INFO << "End Clear servable";
}

Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) {
std::unique_lock<std::mutex> lock{mutex_};
for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) {
if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) {
iter = agent_spec_map_.erase(iter);
} else {
++iter;
}
}
SetWaitAgentsPromise(false);
return SUCCESS;
}

Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint64_t version_number,
uint64_t wait_agents_time_in_seconds) {
if (model_loaded_) {
MSI_LOG_EXCEPTION << "Model has loaded";
}
version_number_ = version_number;
servable_name_ = servable_name;
rank_table_json_file_ = rank_table_json_file;
ServableSignature signature;
if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
}
auto &meta = signature.servable_meta;
if (meta.servable_type != kServableTypeDistributed) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Servable '" << servable_name << "' is not registered as distributed servable, " << meta.Repr();
}
config_.common_meta = meta.common_meta;
config_.distributed_meta = meta.distributed_meta;

auto status = InitConfigOnStartup(rank_table_json_file_);
if (status != SUCCESS) {
MSI_LOG_ERROR << "Init with rank table on start up failed";
return status;
}
status = CheckRankConfig();
if (status != SUCCESS) {
MSI_LOG_ERROR << "Check rank config failed";
return status;
}
status = WaitAgentsReady(wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_ERROR << "Waiting for ready of agents failed";
return status;
}
status = CheckAgentsInfosAndInitTensorInfos();
if (status != SUCCESS) {
MSI_LOG_ERROR << "Check agents infos failed";
return status;
}
model_loaded_ = true;
return SUCCESS;
}

Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; }

Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) {
auto future = agents_promise_.get_future();
if (wait_agents_time_in_seconds == 0) {
wait_agents_time_in_seconds = UINT32_MAX;
}
const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10;
uint64_t i;
for (i = 0; i < kWaitMaxHundredMs; i++) { //
if (ExitSignalHandle::Instance().HasStopped()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped";
}
// waiting for 100ms
if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
auto flag = future.get();
if (!flag) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported";
}
break;
}
}
if (i >= kWaitMaxHundredMs) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size()
<< ", rank size: " << config_.distributed_meta.rank_size;
}
return SUCCESS;
}

Status DistributedServable::CompareTensorInfos(const std::vector<TensorInfo> &lefts,
const std::vector<TensorInfo> &rights) {
if (lefts.size() != rights.size()) {
return INFER_STATUS(FAILED) << "Size not match, left: " << lefts.size() << ", right: " << rights.size();
}
auto tensor_info_as_str = [](const TensorInfo &tensor_info) {
Status status = INFER_STATUS(SUCCESS) << "size: " << tensor_info.size << ", data type: " << tensor_info.data_type
<< ", shape: " << tensor_info.shape;
return status.StatusMessage();
};
for (size_t k = 0; k < lefts.size(); k++) {
auto &left = lefts[k];
auto &right = rights[k];
if (left.size != right.size || left.shape != right.shape || left.data_type != right.data_type) {
return INFER_STATUS(FAILED) << "Index " << k << " tensor not match, left- " << tensor_info_as_str(left)
<< "; right- " << tensor_info_as_str(right);
}
}
return SUCCESS;
}

Status DistributedServable::CheckAgentsInfosAndInitTensorInfos() {
auto rank_size = config_.distributed_meta.rank_size;
auto stage_size = config_.distributed_meta.stage_size;
auto parallel_count = rank_size / stage_size;
MSI_LOG_INFO << "Check agents infos, rank size :" << rank_size << ", stage size: " << stage_size
<< ", parallel count: " << parallel_count;
if (agent_spec_map_.size() != rank_size) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Registered agents size " << agent_spec_map_.size() << " not match rank size " << rank_size;
}

input_infos_ = agent_spec_map_[0].agent_spec_.input_infos;
output_infos_ = agent_spec_map_[rank_size - 1].agent_spec_.output_infos;
batch_size_ = agent_spec_map_[0].agent_spec_.batch_size;
if (input_infos_.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << 0 << " input count cannot be 0";
}
if (output_infos_.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size - 1 << " output count cannot be 0";
}
Status status;
for (size_t i = 0; i < parallel_count; i++) {
auto &agent_spec = agent_spec_map_[i];
status = CompareTensorInfos(agent_spec.agent_spec_.input_infos, input_infos_);
if (status != SUCCESS) {
status = INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank " << i << " input infos not match rank 0, details: " << status.StatusMessage();
return status;
}
}
for (size_t i = parallel_count; i < rank_size; i++) {
auto &agent_spec = agent_spec_map_[i];
if (!agent_spec.agent_spec_.input_infos.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Expect rank " << i << " input count equal to 0";
}
}
for (size_t i = 0; i < rank_size; i++) {
auto &first_item = agent_spec_map_[i];
for (size_t k = 0; k < parallel_count && i + k < rank_size; k++) {
auto rank_id = i + k;
auto &agent_spec = agent_spec_map_[i + k];
status = CompareTensorInfos(agent_spec.agent_spec_.output_infos, first_item.agent_spec_.output_infos);
if (status != SUCCESS) {
status = INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size << " output infos not match rank " << i
<< ", details: " << status.StatusMessage();
return status;
}
if (agent_spec.agent_spec_.batch_size != 0 && agent_spec.agent_spec_.batch_size != batch_size_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Expect rank " << rank_id << " batch size equal to 0 or rank 0 batch size " << batch_size_;
}
}
}
return SUCCESS;
}

Status DistributedServable::CheckRankConfig() {
auto rank_size = config_.distributed_meta.rank_size;
auto stage_size = config_.distributed_meta.stage_size;
if (stage_size == 0 || rank_size == 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size or stage size cannot be 0, rank size: " << rank_size << ", stage size: " << stage_size;
}
if (rank_size % stage_size != 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size must be an integral multiple of stage size, rank size: " << rank_size
<< ", stage size: " << stage_size;
}
if (config_.rank_list.size() != rank_size) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size "
<< rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_;
}
auto parallel_count = rank_size / stage_size;
constexpr size_t card_count_per_machine = 8;
if (stage_size == 1) {
std::map<std::string, std::set<uint32_t>> device_map;
for (size_t i = 0; i < rank_size; i++) {
const auto &item = config_.rank_list[i];
auto &device_id_list = device_map[item.ip];
if (device_id_list.count(item.device_id) > 0) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank "
<< i << " in device ip " << item.ip;
}
device_id_list.emplace(item.device_id);
}
} else {
if (rank_size < card_count_per_machine) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size " << rank_size << "must >= card count " << card_count_per_machine
<< " of one machine when stage size " << stage_size << " > 1";
}
if (parallel_count % card_count_per_machine != 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine
<< "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size;
}
for (size_t i = 0; i < rank_size; i += card_count_per_machine) {
const auto &first_item = config_.rank_list[i];
for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) {
auto rank_id = i + k;
const auto &item = config_.rank_list[rank_id];
if (k != item.device_id) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k;
}
if (first_item.ip != item.ip) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id
<< " to be equal with device ip " << first_item.ip << " of rank " << i;
}
}
}
}
MSI_LOG_INFO << "Check rank table success, rank size: " << rank_size << ", stage size: " << stage_size
<< ", parallel count in one stage: " << parallel_count;
return SUCCESS;
}

void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); }

} // namespace serving
} // namespace mindspore

+ 92
- 0
mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h View File

@@ -0,0 +1,92 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H
#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H

#include <vector>
#include <string>
#include <map>
#include <memory>
#include "worker/sevable_base.h"
#include "worker/distributed_worker/common.h"
#include "worker/distributed_worker/notify_agent/base_notify_agent.h"

namespace mindspore {
namespace serving {

struct DistributedAgentContext {
WorkerAgentSpec agent_spec_;
std::shared_ptr<BaseNotifyAgent> notify_agent_ = nullptr;
};

class MS_API DistributedServable : public ServableBase {
public:
DistributedServable() = default;
~DistributedServable();
// from python, worker.py
Status StartServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint64_t version_number,
uint64_t wait_agents_time_in_seconds);

// invoke from agent
Status GetDistributedServableConfig(DistributedServableConfig *config) const;
// send model and group

// register and unregister agent, agent_spec_list_
Status RegisterAgent(const WorkerAgentSpec &agent_spec);
Status UnregisterAgent(const WorkerAgentSpec &agent_spec);

// predict, use config_ and agent_spec_list_
Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override;

std::vector<TensorInfo> GetInputInfos() const override;
std::vector<TensorInfo> GetOutputInfos() const override;
uint64_t GetBatchSize() const override;
std::string GetServableName() const override;
uint64_t GetServableVersion() const override;
void Clear() override;
void OnAgentFailed();

private:
DistributedServableConfig config_;
std::string servable_name_;
uint64_t version_number_ = 0;
bool model_loaded_ = false;

std::mutex mutex_;
std::map<uint32_t, DistributedAgentContext> agent_spec_map_;
std::string rank_table_json_file_;

std::vector<TensorInfo> input_infos_;
std::vector<TensorInfo> output_infos_;
uint64_t batch_size_ = 0;
std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT;
std::promise<bool> agents_promise_;

Status InitConfigOnStartup(const std::string &rank_table_json_file);
Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds);
Status CheckAgentsInfosAndInitTensorInfos();
Status CompareTensorInfos(const std::vector<TensorInfo> &lefts, const std::vector<TensorInfo> &rights);
Status CheckRankConfig();
void SetWaitAgentsPromise(bool flag);
// agent stubs
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H

+ 42
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
#include <vector>
#include <functional>
#include <future>
#include "common/serving_common.h"
#include "common/servable.h"
#include "proto/ms_agent.pb.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {

class MS_API BaseNotifyAgent {
public:
BaseNotifyAgent() = default;
virtual ~BaseNotifyAgent() = default;
virtual Status Exit() = 0;
virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
DispatchCallback callback) = 0;
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H

+ 66
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc View File

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/notify_agent/notify_agent.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <thread>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "common/grpc_client.h"

namespace mindspore {
namespace serving {

GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &agent_address) {
agent_address_ = agent_address;
std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(agent_address_);
stub_ = proto::MSAgent::NewStub(channel);
}

GrpcNotfiyAgent::~GrpcNotfiyAgent() = default;

Status GrpcNotfiyAgent::Exit() {
if (stub_) {
proto::DistributedExitRequest request;
request.set_address(agent_address_);
proto::DistributedExitReply reply;
grpc::ClientContext context;
const int32_t TIME_OUT = 1;
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
context.set_deadline(deadline);

(void)stub_->Exit(&context, request, &reply);
}
return SUCCESS;
}

Status GrpcNotfiyAgent::DispatchAsync(const proto::DistributedPredictRequest &request,
proto::DistributedPredictReply *reply, DispatchCallback callback) {
if (!stub_) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_;
}
if (!distributed_client_) {
distributed_client_ = std::make_unique<MSDistributedClient>();
distributed_client_->Start();
}
distributed_client_->PredictAsync(request, reply, stub_.get(), callback);
return SUCCESS;
} // namespace serving

} // namespace serving
} // namespace mindspore

+ 48
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
#define MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
#include <vector>
#include <string>
#include <memory>
#include <atomic>
#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"

namespace mindspore {
namespace serving {

class MS_API GrpcNotfiyAgent : public BaseNotifyAgent {
public:
explicit GrpcNotfiyAgent(const std::string &worker_address);
~GrpcNotfiyAgent() override;

Status Exit() override;

Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
DispatchCallback callback) override;

private:
std::string agent_address_;
std::shared_ptr<proto::MSAgent::Stub> stub_ = nullptr;
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H

+ 107
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc View File

@@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <thread>
#include "common/exit_handle.h"
#include "common/grpc_server.h"
#include "common/proto_tensor.h"

namespace mindspore {
namespace serving {

GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip,
uint32_t distributed_worker_port, const std::string &host_ip,
uint32_t host_port)
: distributed_worker_ip_(distributed_worker_ip),
distributed_worker_port_(distributed_worker_port),
host_ip_(host_ip),
host_port_(host_port) {
distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port);
agent_address_ = host_ip_ + ":" + std::to_string(host_port_);
auto channel = GrpcServer::CreateChannel(distributed_worker_address_);
stub_ = proto::MSWorker::NewStub(channel);
}

GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default;

Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) {
const int32_t REGISTER_TIME_OUT = 60;
const int32_t REGISTER_INTERVAL = 1;
auto loop = REGISTER_TIME_OUT;
while (loop-- && !ExitSignalHandle::Instance().HasStopped()) {
MSI_LOG(INFO) << "Register to " << distributed_worker_address_;
proto::AgentRegisterRequest request;
GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request);
proto::AgentRegisterReply reply;
grpc::ClientContext context;
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL);
context.set_deadline(deadline);
grpc::Status status = stub_->AgentRegister(&context, request, &reply);
if (status.ok()) {
MSI_LOG(INFO) << "Register SUCCESS ";
return SUCCESS;
}
MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message();
std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000));
}
if (ExitSignalHandle::Instance().HasStopped()) {
return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration";
}
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut";
}

Status GrpcNotifyDistributeWorker::Unregister() {
if (is_stoped_.load()) {
return SUCCESS;
}
is_stoped_ = true;
proto::AgentExitRequest request;
request.set_address(agent_address_);
proto::AgentExitReply reply;
grpc::ClientContext context;
const int32_t TIME_OUT = 1;
std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
context.set_deadline(deadline);
grpc::Status status = stub_->AgentExit(&context, request, &reply);
if (status.ok()) {
MSI_LOG(INFO) << "Exit SUCCESS ";
return SUCCESS;
}
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed";
}

Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
auto address = worker_ip + ":" + std::to_string(worker_port);
auto channel = GrpcServer::CreateChannel(address);
auto stub = proto::MSWorker::NewStub(channel);

grpc::ClientContext context;
proto::AgentFailedRequest request;
proto::AgentFailedReply reply;
grpc::Status status = stub->AgentFailed(&context, request, &reply);
if (status.ok()) {
MSI_LOG(INFO) << "Success to notify failure of agent";
return SUCCESS;
}
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent";
}

} // namespace serving
} // namespace mindspore

+ 55
- 0
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H
#define MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H
#include <vector>
#include <string>
#include <memory>
#include "common/serving_common.h"
#include "worker/distributed_worker/common.h"
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
namespace mindspore {
namespace serving {

class MS_API GrpcNotifyDistributeWorker {
public:
GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip,
uint32_t agent_port);
~GrpcNotifyDistributeWorker();
Status Register(const std::vector<WorkerAgentSpec> &agent_specs);
Status Unregister();
// from start up, not agent
static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);

private:
std::string distributed_worker_ip_;
uint32_t distributed_worker_port_;
std::string host_ip_;
uint32_t host_port_;
std::string agent_address_;
std::string distributed_worker_address_;
std::unique_ptr<proto::MSWorker::Stub> stub_;
std::atomic<bool> is_stoped_{false};
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H

+ 103
- 0
mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc View File

@@ -0,0 +1,103 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/worker_agent.h"
#include <memory>
#include "worker/distributed_worker/agent_process/agent_process.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#include "common/exit_handle.h"

namespace mindspore {
namespace serving {

WorkerAgent &WorkerAgent::Instance() {
static WorkerAgent instance;
return instance;
}

Status WorkerAgent::Clear() {
if (notify_worker_) {
if (exit_notify_worker_) {
notify_worker_->Unregister();
}
notify_worker_ = nullptr;
}
grpc_server_.Stop();
executor_.UnloadModel();
return SUCCESS;
}

Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) {
// todo : DistributedPredictRequest->RequestBase
// todo : DistributedPredictReply->ReplyBase
return SUCCESS;
}

Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) {
Status status;
config_ = config;
status = executor_.LoadModelFromFile(config);
if (status != SUCCESS) {
MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name
<< ", rank_id: " << config.rank_id << ", device id: " << config.device_id
<< ", model file: " << config.model_file_name
<< ", rank table file: " << config.rank_table_json_file_name
<< ", group config file: " << config.group_file_name;
return status;
}
status = StartGrpcServer();
if (status != SUCCESS) {
MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip
<< ", agent port: " << config.agent_port;
return status;
}
status = RegisterAgent();
if (status != SUCCESS) {
MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port
<< ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port;
return status;
}
MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name
<< ", rank_id: " << config.rank_id << ", device id: " << config.device_id
<< ", model file: " << config.model_file_name
<< ", rank table file: " << config.rank_table_json_file_name
<< ", group config file: " << config.group_file_name;
return SUCCESS;
}

Status WorkerAgent::StartGrpcServer() {
grpc_server_.Start(std::make_shared<MSAgentImpl>(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent");
return SUCCESS;
}

Status WorkerAgent::RegisterAgent() {
notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.agent_port, config_.agent_ip,
config_.agent_port);
WorkerAgentSpec spec;
spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port);
spec.rank_id = config_.rank_id;
spec.batch_size = executor_.GetBatchSize();
spec.input_infos = executor_.GetInputInfos();
spec.output_infos = executor_.GetOutputInfos();
return notify_worker_->Register({spec});
}

void WorkerAgent::StopAgent(bool notify_worker) {
exit_notify_worker_ = notify_worker;
ExitSignalHandle::Instance().Stop();
}

} // namespace serving
} // namespace mindspore

+ 55
- 0
mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_AGENT_H
#define MINDSPORE_SERVING_WORKER_AGENT_H
#include <vector>
#include <memory>
#include "worker/distributed_worker/agent_executor.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
#include "common/grpc_server.h"
#include "worker/distributed_worker/common.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"

namespace mindspore {
namespace serving {
class MS_API WorkerAgent {
public:
static WorkerAgent &Instance();
Status Clear();

Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply);

Status StartAgent(const AgentStartUpConfig &config);

void StopAgent(bool notify_worker = true);

private:
AgentStartUpConfig config_;
WorkerAgentExecutor executor_;
GrpcServer grpc_server_;
bool exit_notify_worker_ = true;
std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_;

Status StartGrpcServer();
Status RegisterAgent();
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_AGENT_H

+ 0
- 1
mindspore_serving/ccsrc/worker/grpc/worker_process.cc View File

@@ -15,7 +15,6 @@
*/
#include "worker/grpc/worker_process.h"
#include "master/dispacther.h"
#include "worker/worker.h"
namespace mindspore {


+ 1
- 1
mindspore_serving/ccsrc/worker/grpc/worker_process.h View File

@@ -28,7 +28,7 @@ namespace mindspore {
namespace serving {
// Service Implement
class MSWorkerImpl final : public proto::MSWorker::Service {
class MSWorkerImpl : public proto::MSWorker::Service {
public:
grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply) override;


+ 15
- 3
mindspore_serving/ccsrc/worker/grpc/worker_server.cc View File

@@ -21,12 +21,20 @@
namespace mindspore {
namespace serving {
MSWorkerServer::~MSWorkerServer() { Stop(); }
MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) {
Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
if (in_running_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
}
service_impl_ = std::make_unique<MSWorkerImpl>();
async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get());
return Init();
}
MSWorkerServer::MSWorkerServer() = default;
Status MSWorkerServer::Init() {
Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize);
if (status != SUCCESS) return status;
@@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() {
return status;
}
Status MSWorkerServer::Stop() {
if (in_running_) {
if (in_running_ && async_server_) {
async_server_->Stop();
grpc_thread_.join();
if (grpc_thread_.joinable()) {
grpc_thread_.join();
}
}
async_server_ = nullptr;
service_impl_ = nullptr;
in_running_ = false;
return SUCCESS;
}


+ 27
- 30
mindspore_serving/ccsrc/worker/grpc/worker_server.h View File

@@ -27,40 +27,53 @@
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/distributed_worker/distributed_servable.h"
namespace mindspore {
namespace serving {
// Service Implement
class MSWorkerServer {
class MS_API MSWorkerServer {
public:
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
MSWorkerServer(const std::string &hostname, int32_t port);
~MSWorkerServer();
Status Init();
MSWorkerServer();
virtual ~MSWorkerServer();
virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port);
Status Stop();
Status StartAsyncRpcService();
protected:
bool in_running_ = false;
std::thread grpc_thread_;
std::unique_ptr<MSWorkerImpl> service_impl_;
std::unique_ptr<GrpcAsyncServer> async_server_;
std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr;
std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr;
Status Init();
Status StartAsyncRpcService();
};
class WorkerServiceContext {
public:
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq) {
state_ = STATE::CREATE;
}
virtual ~WorkerServiceContext() {}
bool JudgeFinish() { return state_ == STATE::FINISH; }
virtual void StartEnqueueRequest() = 0;
virtual void HandleRequest() = 0;
virtual bool JudgeFinish() = 0;
protected:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
public:
STATE state_;
};
@@ -68,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext {
public:
WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerPredictContext() = default;
@@ -93,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext {
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
proto::PredictRequest request_;
proto::PredictReply response_;
@@ -109,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext {
public:
WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerExitContext() = default;
@@ -134,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext {
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
proto::ExitRequest request_;
proto::ExitReply response_;
@@ -174,7 +171,7 @@ class WorkerGrpcServer : public GrpcAsyncServer {
return SUCCESS;
}
private:
protected:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService svc_;
};


+ 0
- 126
mindspore_serving/ccsrc/worker/inference/inference.h View File

@@ -52,132 +52,6 @@ enum DeviceType {
kDeviceTypeCpu,
};

class MS_API InferSession {
public:
InferSession() = default;
virtual ~InferSession() = default;
virtual Status InitEnv(DeviceType device_type, uint32_t device_id,
const std::map<std::string, std::string> &other_options) = 0;
virtual Status FinalizeEnv() = 0;

virtual Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name,
ModelType model_type, const std::vector<int> &without_batch_dim_inputs,
const std::map<std::string, std::string> &other_options, uint32_t *model_id) = 0;

virtual Status UnloadModel(uint32_t model_id) = 0;
// override this method to avoid request/reply data copy
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) = 0;
virtual Status ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request,
std::vector<TensorBasePtr> *reply) {
VectorTensorPtrWrapRequest wrap_request(request);
VectorTensorPtrWrapReply wrap_reply(reply, []() { return std::make_shared<Tensor>(); });
return ExecuteModel(model_id, wrap_request, &wrap_reply);
}

virtual std::vector<TensorInfo> GetInputInfos(uint32_t model_id) const = 0;
virtual std::vector<TensorInfo> GetOutputInfos(uint32_t model_id) const = 0;
virtual ssize_t GetBatchSize(uint32_t model_id) const = 0;
virtual bool CheckModelSupport(DeviceType device_type, ModelType model_type) const { return true; }
};

struct InferSessionRegInfo {
std::shared_ptr<InferSession> session;
ModelType model_type;
int priority;
};

class MS_API InferSessionStorage {
public:
void Register(DeviceType device_type, ModelType model_type, const std::shared_ptr<InferSession> &session,
int priority) {
auto &list = session_map_[device_type];
InferSessionRegInfo info{session, model_type, priority};
list.push_back(info);
}

std::shared_ptr<InferSession> Get(DeviceType device_type, ModelType model_type, DeviceType *specified_device_type) {
MSI_EXCEPTION_IF_NULL(specified_device_type);
if (device_type == kDeviceTypeNotSpecified) {
for (auto &item_device : session_map_) {
std::shared_ptr<InferSession> ret_session = GetSession(item_device.second, item_device.first, model_type);
if (ret_session) {
*specified_device_type = item_device.first;
return ret_session;
}
}
return nullptr;
} else if (device_type == kDeviceTypeAscend) {
auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS};
for (auto ascend_type : ascend_list) {
auto it = session_map_.find(ascend_type);
if (it == session_map_.end()) {
continue;
}
auto session_ret = GetSession(it->second, ascend_type, model_type);
if (session_ret != nullptr) {
*specified_device_type = ascend_type;
return session_ret;
}
}
return nullptr;
}
auto it = session_map_.find(device_type);
if (it == session_map_.end()) {
return nullptr;
}
std::shared_ptr<InferSession> session_ret;
session_ret = GetSession(it->second, device_type, model_type);
*specified_device_type = device_type;
return session_ret;
}

static InferSessionStorage &Instance() {
static InferSessionStorage instance;
return instance;
}

private:
std::unordered_map<DeviceType, std::vector<InferSessionRegInfo>> session_map_;

std::shared_ptr<InferSession> GetSession(const std::vector<InferSessionRegInfo> &session_list, DeviceType device_type,
ModelType model_type) {
std::shared_ptr<InferSession> session_ret = nullptr;
int cur_priority = INT32_MIN;
for (auto &item : session_list) {
if (item.model_type != model_type) {
continue;
}
if (session_ret == nullptr || cur_priority < item.priority) {
if (!item.session->CheckModelSupport(device_type, model_type)) {
MSI_LOG_INFO << "CheckModelSupport for " << device_type << " " << model_type << " failed, skipped";
continue;
}
cur_priority = item.priority;
session_ret = item.session;
}
}
return session_ret;
}
};

class MS_API InferSessionRegister {
public:
InferSessionRegister(DeviceType device_type, ModelType model_type, const std::shared_ptr<InferSession> &session,
int priority) {
InferSessionStorage::Instance().Register(device_type, model_type, session, priority);
}
};

#define REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) \
static mindspore::serving::InferSessionRegister g_register_session_##cls_name##_##index( \
device_type, model_type, std::make_shared<cls_name>(), priority);

#define REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, index) \
REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index)

#define REGISTER_INFER_SEESION(device_type, model_type, cls_name, priority) \
REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, __COUNTER__);

static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) {
switch (device_type) {
case kDeviceTypeAscend:


+ 62
- 84
mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc View File

@@ -26,16 +26,6 @@
namespace mindspore {
namespace serving {

Status MindSporeModelWrap::InitEnv(serving::DeviceType device_type, uint32_t device_id,
const std::map<std::string, std::string> &other_options) {
return SUCCESS;
}

Status MindSporeModelWrap::FinalizeEnv() {
model_map_.clear();
return SUCCESS;
}

mindspore::DataType TransInferDataType2ApiTypeId(DataType data_type) {
const std::map<DataType, mindspore::DataType> type2id_map{
{serving::kMSI_Unknown, mindspore::DataType::kTypeUnknown},
@@ -81,11 +71,9 @@ DataType TransTypeId2InferDataType(mindspore::DataType type_id) {
}

Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id,
const std::string &file_name, ModelType model_type,
const std::string &file_name, ModelType model_type, bool with_batch_dim,
const std::vector<int> &without_batch_dim_inputs,
const std::map<std::string, std::string> &other_options,
uint32_t *model_id) {
MSI_EXCEPTION_IF_NULL(model_id);
const std::map<std::string, std::string> &other_options) {
std::string device_type_str;
if (device_type == kDeviceTypeAscendMS) {
device_type_str = mindspore::kDeviceTypeAscend910;
@@ -113,18 +101,18 @@ Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, ui
<< "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options;
return Status(FAILED, status.ToString());
}
model_index_++;
*model_id = model_index_;
ApiModelInfo api_model_info;
api_model_info.model = model;
api_model_info.device_type = device_type_str;
api_model_info.device_id = device_id;
api_model_info.with_batch_dim = with_batch_dim;
api_model_info.without_batch_dim_inputs = without_batch_dim_inputs;
auto st = GetModelInfos(&api_model_info);
if (st != SUCCESS) {
return st;
}
model_map_[*model_id] = api_model_info;
GetModelBatchSize(&api_model_info);
model_ = api_model_info;
MSI_LOG_INFO << "Load model from file success, model file: " << file_name << ", device_type: '" << device_type_str
<< "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options;
return SUCCESS;
@@ -169,20 +157,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) {
MSI_EXCEPTION_IF_NULL(api_model_info);
auto model = api_model_info->model;

bool first_dim_same = true;
auto find_batch_size = [&first_dim_same, api_model_info](const std::vector<int64_t> &shape) {
if (first_dim_same) {
if (shape.empty()) {
first_dim_same = false;
} else if (api_model_info->batch_size != 0) {
if (api_model_info->batch_size != shape[0]) {
first_dim_same = false;
}
} else {
api_model_info->batch_size = shape[0];
}
}
};
auto get_tensor_info_from_tensor = [](const mindspore::MSTensor &ms_tensor) {
serving::TensorInfo tensor_info;
tensor_info.shape = ms_tensor.Shape();
@@ -204,10 +178,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Unknown input mindspore data type " << static_cast<int>(info.DataType());
}
const auto &list = api_model_info->without_batch_dim_inputs;
if (std::find(list.begin(), list.end(), i) == list.end()) {
find_batch_size(tensor_info.shape);
}
api_model_info->input_tensor_infos.push_back(tensor_info);
api_model_info->input_names.push_back(info.Name());
}
@@ -220,27 +190,59 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Unknown output mindspore data type " << static_cast<int>(info.DataType());
}
find_batch_size(tensor_info.shape);
api_model_info->output_tensor_infos.push_back(tensor_info);
api_model_info->output_names.push_back(info.Name());
}
}
return SUCCESS;
}

void MindSporeModelWrap::GetModelBatchSize(ApiModelInfo *api_model_info) {
MSI_EXCEPTION_IF_NULL(api_model_info);
bool first_dim_same = true;
auto find_batch_size = [&first_dim_same, api_model_info](const std::vector<int64_t> &shape) {
if (!api_model_info->with_batch_dim) {
first_dim_same = false;
return;
}
if (!first_dim_same) {
return;
}
if (shape.empty()) {
first_dim_same = false;
return;
}
if (api_model_info->batch_size != 0) {
if (api_model_info->batch_size != shape[0]) {
first_dim_same = false;
}
} else {
api_model_info->batch_size = shape[0];
}
};

auto list = api_model_info->without_batch_dim_inputs;
auto size = api_model_info->input_tensor_infos.size();
for (size_t i = 0; i < size; i++) {
if (std::find(list.begin(), list.end(), i) == list.end()) {
auto &info = api_model_info->input_tensor_infos[i];
find_batch_size(info.shape);
}
}
for (auto &info : api_model_info->output_tensor_infos) {
find_batch_size(info.shape);
}
if (!first_dim_same) {
api_model_info->batch_size = 0;
}
return SUCCESS;
}

Status MindSporeModelWrap::UnloadModel(uint32_t model_id) {
auto it = model_map_.find(model_id);
if (it == model_map_.end()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id;
}
model_map_.erase(it);
Status MindSporeModelWrap::UnloadModel() {
model_.model = nullptr;
return SUCCESS;
}

Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &request, serving::ReplyBase *reply) {
Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::ReplyBase *reply) {
MSI_EXCEPTION_IF_NULL(reply);
FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) {
auto input_tensor = request[index];
@@ -260,11 +262,10 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &re
tensor->set_data_type(data_type);
tensor->set_shape(shape);
};
return ExecuteModelCommon(model_id, request.size(), func_in, func_out);
return ExecuteModelCommon(request.size(), func_in, func_out);
}

Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request,
std::vector<TensorBasePtr> *reply) {
Status MindSporeModelWrap::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) {
MSI_EXCEPTION_IF_NULL(reply);
FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) {
auto &input_tensor = request[index];
@@ -282,16 +283,15 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector<Ten
tensor->set_shape(shape);
reply->push_back(tensor);
};
return ExecuteModelCommon(model_id, request.size(), func_in, func_out);
return ExecuteModelCommon(request.size(), func_in, func_out);
}

Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func,
Status MindSporeModelWrap::ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func,
const FuncMakeOutTensor &out_func) {
auto it = model_map_.find(model_id);
if (it == model_map_.end()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id;
if (model_.model == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded";
}
auto &model_info = it->second;
auto &model_info = model_;
auto model = model_info.model;
auto &input_names = model_info.input_names;
auto &output_names = model_info.output_names;
@@ -327,43 +327,25 @@ Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_
return SUCCESS;
}

std::vector<serving::TensorInfo> MindSporeModelWrap::GetInputInfos(uint32_t model_id) const {
auto it = model_map_.find(model_id);
if (it == model_map_.end()) {
MSI_LOG_ERROR << "Invalid model id " << model_id;
return {};
}
auto &model_info = it->second;
return model_info.input_tensor_infos;
}
std::vector<serving::TensorInfo> MindSporeModelWrap::GetInputInfos() const { return model_.input_tensor_infos; }

std::vector<serving::TensorInfo> MindSporeModelWrap::GetOutputInfos(uint32_t model_id) const {
auto it = model_map_.find(model_id);
if (it == model_map_.end()) {
MSI_LOG_ERROR << "Invalid model id " << model_id;
return {};
}
auto &model_info = it->second;
return model_info.output_tensor_infos;
}
std::vector<serving::TensorInfo> MindSporeModelWrap::GetOutputInfos() const { return model_.output_tensor_infos; }

ssize_t MindSporeModelWrap::GetBatchSize(uint32_t model_id) const {
auto it = model_map_.find(model_id);
if (it == model_map_.end()) {
MSI_LOG_ERROR << "Invalid model id " << model_id;
return {};
}
auto &model_info = it->second;
return model_info.batch_size;
}
ssize_t MindSporeModelWrap::GetBatchSize() const { return model_.batch_size; }

bool MindSporeModelWrap::CheckModelSupport(DeviceType device_type, ModelType model_type) const {
std::string device_type_str;
switch (device_type) {
case kDeviceTypeAscendMS:
if (model_type != kMindIR) {
return false;
}
device_type_str = mindspore::kDeviceTypeAscend910;
break;
case kDeviceTypeAscendCL:
if (model_type != kMindIR && model_type != kOM) {
return false;
}
device_type_str = mindspore::kDeviceTypeAscend310;
break;
default:
@@ -378,9 +360,5 @@ ApiBufferTensorWrap::ApiBufferTensorWrap(const mindspore::MSTensor &tensor) : te

ApiBufferTensorWrap::~ApiBufferTensorWrap() = default;

REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kOM, MindSporeModelWrap, 1);
REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kMindIR, MindSporeModelWrap, 1);
REGISTER_INFER_SEESION(serving::kDeviceTypeAscendMS, kMindIR, MindSporeModelWrap, 1);

} // namespace serving
} // namespace mindspore

+ 15
- 23
mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h View File

@@ -34,54 +34,46 @@ struct ApiModelInfo {
std::vector<serving::TensorInfo> input_tensor_infos;
std::vector<std::string> output_names;
std::vector<serving::TensorInfo> output_tensor_infos;
std::shared_ptr<mindspore::Model> model;
std::shared_ptr<mindspore::Model> model = nullptr;
uint32_t batch_size = 0;
std::string device_type;
uint32_t device_id = 0;
bool with_batch_dim = false;
std::vector<int> without_batch_dim_inputs;
};

class MindSporeModelWrap : public InferSession {
class MindSporeModelWrap {
public:
MindSporeModelWrap() = default;

~MindSporeModelWrap() = default;

Status InitEnv(serving::DeviceType device_type, uint32_t device_id,
const std::map<std::string, std::string> &other_options) override;

Status FinalizeEnv() override;

Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name,
ModelType model_type, const std::vector<int> &without_batch_dim_inputs,
const std::map<std::string, std::string> &other_options, uint32_t *model_id) override;

Status UnloadModel(uint32_t model_id) override;
ModelType model_type, bool with_batch_dim, const std::vector<int> &without_batch_dim_inputs,
const std::map<std::string, std::string> &other_options);

// override this method to avoid request/reply data copy
Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) override;
Status ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request,
std::vector<TensorBasePtr> *reply) override;
Status UnloadModel();
Status ExecuteModel(const RequestBase &request, ReplyBase *reply);
Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply);

std::vector<serving::TensorInfo> GetInputInfos(uint32_t model_id) const override;
std::vector<serving::TensorInfo> GetInputInfos() const;

std::vector<serving::TensorInfo> GetOutputInfos(uint32_t model_id) const override;
std::vector<serving::TensorInfo> GetOutputInfos() const;

ssize_t GetBatchSize(uint32_t model_id) const override;
ssize_t GetBatchSize() const;

bool CheckModelSupport(DeviceType device_type, ModelType model_type) const override;
bool CheckModelSupport(DeviceType device_type, ModelType model_type) const;

private:
std::unordered_map<uint32_t, ApiModelInfo> model_map_;
uint32_t model_index_ = 0;
ApiModelInfo model_;

using FuncMakeInBuffer = std::function<mindspore::MSTensor(size_t index, const std::string &name)>;
using FuncMakeOutTensor =
std::function<void(const mindspore::MSTensor, DataType data_type, const std::vector<int64_t> &shape)>;
Status ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func,
const FuncMakeOutTensor &out_func);
Status ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func);
Status GetModelInfos(ApiModelInfo *model_info);
std::shared_ptr<Context> TransformModelContext(const std::map<std::string, std::string> &other_options);
void GetModelBatchSize(ApiModelInfo *model_info);
};

class ApiBufferTensorWrap : public TensorBase {


+ 254
- 0
mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc View File

@@ -0,0 +1,254 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "worker/local_servable/local_sevable.h"
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <string>
#include "common/tensor.h"
#include "common/file_system_operation.h"
#include "worker/context.h"

namespace {
static const char *kVersionStrategyLatest = "latest";
static const char *kVersionStrategySpecific = "specific";
} // namespace

namespace mindspore::serving {

LocalModelServable::~LocalModelServable() { Clear(); }

std::string LocalModelServable::GetServableName() const { return servable_name_; }

uint64_t LocalModelServable::GetServableVersion() const { return version_number_; }

Status LocalModelServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return session_.ExecuteModel(input, output);
}

std::vector<TensorInfo> LocalModelServable::GetInputInfos() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return session_.GetInputInfos();
}

std::vector<TensorInfo> LocalModelServable::GetOutputInfos() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return session_.GetOutputInfos();
}

uint64_t LocalModelServable::GetBatchSize() const {
if (!model_loaded_) {
MSI_LOG_EXCEPTION << "Model has not been loaded";
}
return session_.GetBatchSize();
}

Status LocalModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name,
uint64_t version_number) {
if (model_loaded_) {
MSI_LOG_EXCEPTION << "Model has loaded";
}
base_spec_.servable_directory = servable_directory;
base_spec_.servable_name = servable_name;
base_spec_.version_number = version_number;

std::string version_strategy;
if (version_number == 0) {
version_strategy = kVersionStrategyLatest;
} else {
version_strategy = kVersionStrategySpecific;
}
Status status;
ServableSignature signature;
if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
}
status = InitDevice(signature.servable_meta.local_meta.model_format, {});
if (status != SUCCESS) {
MSI_LOG_ERROR << "Init env failed";
return status;
}

std::vector<uint64_t> real_versions;
status = LoadServableConfig(base_spec_, version_strategy, &real_versions);
if (status != SUCCESS) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Start servable failed, there is no servable of the specified version number, specified version number: "
<< version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '"
<< base_spec_.servable_name
<< "'. version number is a positive integer(started from 1) and 0 represents the maximum version number.";
}
auto real_version_number = real_versions[0];
status = LoadModel(real_version_number);
if (status != SUCCESS) {
return status;
}
servable_name_ = base_spec_.servable_name;
version_number_ = real_version_number;
model_loaded_ = true;
MSI_LOG_INFO << status.StatusMessage();
std::cout << status.StatusMessage() << std::endl;
return SUCCESS;
}

void LocalModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions) {
MSI_EXCEPTION_IF_NULL(real_versions);
// define version_strategy:"specific","latest","multi"
if (version_strategy_ == kVersionStrategySpecific) {
real_versions->push_back(servable_spec.version_number);
return;
}
auto trans_to_integer = [](const std::string &str) -> uint32_t {
uint32_t parsed_value = 0;
for (auto c : str) {
if (c < '0' || c > '9') {
return 0;
}
parsed_value = parsed_value * 10 + c - '0';
}
if (std::to_string(parsed_value) != str) {
return 0;
}
return parsed_value;
};
uint64_t newest_version = 0;
std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name;
auto sub_dir = GetAllSubDirsNotFullPath(model_path);
static std::set<std::string> ignore_dir;
for (const auto &dir : sub_dir) {
if (dir == "__pycache__") continue;
auto version_parse = trans_to_integer(dir);
if (version_parse == 0) {
if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) {
MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory
<< ", model_name " << servable_spec.servable_name;
}
continue;
}
real_versions->push_back(version_parse);
if (version_parse > newest_version) {
newest_version = version_parse;
}
}
if (version_strategy_ == kVersionStrategyLatest) {
real_versions->clear();
if (newest_version != 0) {
real_versions->push_back(newest_version);
}
}
}

Status LocalModelServable::LoadServableConfig(const LoadServableSpec &servable_spec,
const std::string &version_strategy,
std::vector<uint64_t> *real_versions) {
MSI_EXCEPTION_IF_NULL(real_versions);
auto model_directory = servable_spec.servable_directory;
auto model_name = servable_spec.servable_name;

if (!DirOrFileExist(model_directory + "/" + model_name)) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model not found, model_directory " << model_directory << ", model_name " << model_name;
}
std::string model_path = model_directory + "/" + model_name;
auto version_directory = [model_path](int64_t version_number) {
return model_path + "/" + std::to_string(version_number);
};
version_strategy_ = version_strategy;
// version_strategy:"specific","latest","multi"
GetVersions(servable_spec, real_versions);
if (real_versions->size() == 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name;
}
for (auto real_version_number : *real_versions) {
if (!DirOrFileExist(version_directory(real_version_number))) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory "
<< model_directory << ", model_name " << model_name;
}
}
return SUCCESS;
}

Status LocalModelServable::InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options) {
Status status;
auto context = ServableContext::Instance();
DeviceType device_type = ServableContext::Instance()->GetDeviceType();
auto get_support_device_type = [this, device_type, model_type]() {
std::vector<DeviceType> support_device_list;
if (device_type == kDeviceTypeNotSpecified || device_type == kDeviceTypeAscend) {
auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS};
for (auto item : ascend_list) {
if (session_.CheckModelSupport(item, model_type)) {
return item;
}
}
} else if (device_type == kDeviceTypeAscendCL || device_type == kDeviceTypeAscendMS) {
if (session_.CheckModelSupport(device_type, model_type)) {
return device_type;
}
}
return kDeviceTypeNotSpecified;
};
auto support_device_type = get_support_device_type();
if (support_device_type == kDeviceTypeNotSpecified) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Not support device type " << device_type << " and model type " << model_type
<< ". Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model";
}
context->SetDeviceType(support_device_type);
return SUCCESS;
}

Status LocalModelServable::LoadModel(uint64_t version_number) {
ServableSignature signature;
if (!ServableStorage::Instance().GetServableDef(base_spec_.servable_name, &signature)) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << base_spec_.servable_name << " has not been registered";
}
const auto &servable_meta = signature.servable_meta;
const auto &common_meta = servable_meta.common_meta;
const auto &local_meta = servable_meta.local_meta;
std::string model_file_name = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" +
std::to_string(version_number) + "/" + local_meta.servable_file;
auto context = ServableContext::Instance();
Status status = session_.LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name,
local_meta.model_format, common_meta.with_batch_dim,
common_meta.without_batch_dim_inputs, local_meta.load_options);
if (status != SUCCESS) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '"
<< base_spec_.servable_name << "', servable file: '" << local_meta.servable_file << "', version number "
<< version_number << ", options " << local_meta.load_options;
}
return SUCCESS;
}

void LocalModelServable::Clear() {
if (model_loaded_) {
session_.UnloadModel();
}
model_loaded_ = false;
}

} // namespace mindspore::serving

+ 69
- 0
mindspore_serving/ccsrc/worker/local_servable/local_sevable.h View File

@@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H
#define MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H

#include <memory>
#include <vector>
#include <string>
#include <map>

#include "common/serving_common.h"
#include "common/instance.h"
#include "common/servable.h"
#include "worker/sevable_base.h"
#include "worker/inference/inference.h"
#include "worker/inference/mindspore_model_wrap.h"

namespace mindspore::serving {

class MS_API LocalModelServable : public ServableBase {
public:
LocalModelServable() = default;
~LocalModelServable() override;

Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override;

std::vector<TensorInfo> GetInputInfos() const override;
std::vector<TensorInfo> GetOutputInfos() const override;
uint64_t GetBatchSize() const override;

Status StartServable(const std::string &servable_directory, const std::string &servable_name,
uint64_t version_number);
Status InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options);
std::string GetServableName() const override;
uint64_t GetServableVersion() const override;
void Clear() override;

private:
LoadServableSpec base_spec_;
std::string servable_name_;
uint64_t version_number_ = 0;

MindSporeModelWrap session_;
std::string version_strategy_;
bool model_loaded_ = false;

void GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions);
Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy,
std::vector<uint64_t> *real_version_number);
Status LoadModel(uint64_t version);
};

} // namespace mindspore::serving

#endif // MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H

+ 0
- 33
mindspore_serving/ccsrc/worker/model.cc View File

@@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "worker/model.h"
#include <algorithm>
#include "mindspore_serving/ccsrc/common/tensor.h"

namespace mindspore::serving {

Status AscendModelServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) {
return session_->ExecuteModel(model_id_, input, output);
}

std::vector<TensorInfo> AscendModelServable::GetInputInfos() const { return session_->GetInputInfos(model_id_); }

std::vector<TensorInfo> AscendModelServable::GetOutputInfos() const { return session_->GetOutputInfos(model_id_); }

uint64_t AscendModelServable::GetBatchSize() const { return session_->GetBatchSize(model_id_); }

} // namespace mindspore::serving

mindspore_serving/ccsrc/worker/model.h → mindspore_serving/ccsrc/worker/sevable_base.h View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_MODEL_H
#define MINDSPORE_SERVING_WORKER_MODEL_H
#ifndef MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H
#define MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H

#include <memory>
#include <unordered_map>
@@ -39,25 +39,11 @@ class ServableBase {
virtual std::vector<TensorInfo> GetInputInfos() const = 0;
virtual std::vector<TensorInfo> GetOutputInfos() const = 0;
virtual uint64_t GetBatchSize() const = 0;
};

class AscendModelServable : public ServableBase {
public:
AscendModelServable(const std::shared_ptr<serving::InferSession> &session, uint32_t model_id)
: session_(session), model_id_(model_id) {}
~AscendModelServable() = default;

Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override;

std::vector<TensorInfo> GetInputInfos() const override;
std::vector<TensorInfo> GetOutputInfos() const override;
uint64_t GetBatchSize() const override;

private:
std::shared_ptr<serving::InferSession> session_{nullptr};
uint32_t model_id_ = 0;
virtual std::string GetServableName() const = 0;
virtual uint64_t GetServableVersion() const = 0;
virtual void Clear() = 0;
};

} // namespace mindspore::serving

#endif // MINDSPORE_SERVING_WORKER_MODEL_H
#endif // MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H

+ 11
- 11
mindspore_serving/ccsrc/worker/work_executor.cc View File

@@ -49,15 +49,15 @@ Status WorkExecutor::CheckSevableSignature() {
if (servable_declare_.methods.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "There is no method registered for servable";
}
if (input_infos.size() != servable_declare_.servable_meta.inputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "The inputs count " << servable_declare_.servable_meta.inputs_count << " registered in method "
<< "not equal to the count " << input_infos.size() << " defined in servable";
const auto &common_meta = servable_declare_.servable_meta.common_meta;
if (input_infos.size() != common_meta.inputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count << " registered in method "
<< "not equal to the count " << input_infos.size() << " defined in servable";
}
const auto &output_infos = output_infos_;
if (output_infos.size() != servable_declare_.servable_meta.outputs_count) {
if (output_infos.size() != common_meta.outputs_count) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "The outputs count " << servable_declare_.servable_meta.outputs_count << " registered in method "
<< "The outputs count " << common_meta.outputs_count << " registered in method "
<< "not equal to the count " << output_infos.size() << " defined in servable";
}
MSI_LOG_INFO << "Model input infos: count " << input_infos.size();
@@ -68,7 +68,7 @@ Status WorkExecutor::CheckSevableSignature() {
for (auto &item : output_infos) {
MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size;
}
if (servable_declare_.servable_meta.with_batch_dim) {
if (common_meta.with_batch_dim) {
if (model_batch_size_ == 0) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable batch size cannot be " << model_batch_size_;
}
@@ -104,7 +104,7 @@ Status WorkExecutor::Init(const ServableSignature &servable_declare, const std::
servable_ = servable;
input_infos_ = servable_->GetInputInfos();
output_infos_ = servable_->GetOutputInfos();
if (servable_declare_.servable_meta.with_batch_dim) {
if (servable_declare_.servable_meta.common_meta.with_batch_dim) {
model_batch_size_ = servable_->GetBatchSize();
} else {
model_batch_size_ = 1;
@@ -389,7 +389,7 @@ Status WorkExecutor::PostPredict(const std::vector<Instance> &inputs, const std:
MSI_LOG_EXCEPTION << "Output result data size cannot be 0";
}
auto shape = item->shape();
if (servable_declare_.servable_meta.with_batch_dim) {
if (servable_declare_.servable_meta.common_meta.with_batch_dim) {
if (shape.empty() || shape[0] != model_batch_size) {
MSI_LOG_EXCEPTION << "Output shape " << shape << " not match batch size " << model_batch_size;
}
@@ -429,9 +429,9 @@ Status WorkExecutor::Predict(const std::vector<Instance> &inputs, std::vector<In
}

bool WorkExecutor::IsNoBatchDimInput(int input_index) const {
auto without_batch_dim_inputs = servable_declare_.servable_meta.without_batch_dim_inputs;
auto without_batch_dim_inputs = servable_declare_.servable_meta.common_meta.without_batch_dim_inputs;
bool no_batch_dim = true;
if (servable_declare_.servable_meta.with_batch_dim) {
if (servable_declare_.servable_meta.common_meta.with_batch_dim) {
no_batch_dim = std::find(without_batch_dim_inputs.begin(), without_batch_dim_inputs.end(), input_index) !=
without_batch_dim_inputs.end();
}


+ 3
- 5
mindspore_serving/ccsrc/worker/work_executor.h View File

@@ -28,7 +28,7 @@
#include "common/serving_common.h"
#include "common/instance.h"
#include "common/servable.h"
#include "worker/model.h"
#include "worker/sevable_base.h"
#include "worker/predict_thread.h"
#include "worker/task_queue.h"

@@ -39,10 +39,8 @@ using WorkCallBack = std::function<void(const Instance &output, const Status &er

class WorkExecutor {
public:
WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess_task_queue,
std::shared_ptr<TaskQueue> py_postprocess_task_queue,
std::shared_ptr<TaskQueue> cpp_preprocess_task_queue,
std::shared_ptr<TaskQueue> cpp_postprocess_task_queue);
WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess, std::shared_ptr<TaskQueue> py_postprocess,
std::shared_ptr<TaskQueue> cpp_preprocess, std::shared_ptr<TaskQueue> cpp_postprocess);
~WorkExecutor();

Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable);


+ 57
- 264
mindspore_serving/ccsrc/worker/worker.cc View File

@@ -34,46 +34,16 @@ namespace py = pybind11;
namespace mindspore {
namespace serving {

static const char *kVersionStrategyLastest = "lastest";
static const char *kVersionStrategySpecific = "specific";
static std::unique_ptr<MSWorkerServer> grpc_async_worker_server_;

Worker &Worker::GetInstance() {
static Worker instance;
return instance;
}

Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) {
if (grpc_async_worker_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running";
}
grpc_async_worker_server_ = std::make_unique<MSWorkerServer>(ip, grpc_port);
return grpc_async_worker_server_->Init();
}

Status Worker::RegisterWorker() {
std::vector<LoadServableSpec> specs;
std::vector<ServableSignature> signatures;
for (auto &work : work_list_) {
specs.push_back(work.servable_spec);
signatures.push_back(work.servable_signature);
}
std::vector<WorkerSpec> worker_specs;
for (size_t i = 0; i < specs.size(); i++) {
auto &spec = specs[i];
auto &servable_signature = signatures[i];
WorkerSpec worker_spec;
worker_spec.servable_name = spec.servable_name;
worker_spec.version_number = spec.version_number;
for (auto &method : servable_signature.methods) {
WorkerMethodInfo worker_method_info;
worker_method_info.name = method.method_name;
for (auto &name : method.inputs) {
worker_method_info.input_names.push_back(name);
}
worker_spec.methods.push_back(worker_method_info);
}
worker_specs.push_back(worker_spec);
for (auto &work : work_list_) {
// cppcheck-suppress useStlAlgorithm
worker_specs.push_back(work.worker_spec);
}
auto status = notify_master_->Register(worker_specs);
return status;
@@ -84,34 +54,10 @@ Status Worker::StartVersionController() {
return SUCCESS;
}

Status Worker::AddWorker(const ServableWorkerContext &work) {
WorkerSpec worker_spec;
worker_spec.servable_name = work.servable_spec.servable_name;
worker_spec.version_number = work.servable_spec.version_number;
for (auto &method : work.servable_signature.methods) {
WorkerMethodInfo worker_method_info;
worker_method_info.name = method.method_name;
for (auto &name : method.inputs) {
worker_method_info.input_names.push_back(name);
}
worker_spec.methods.push_back(worker_method_info);
}
return notify_master_->AddWorker(worker_spec);
}
Status Worker::AddWorker(const ServableWorkerContext &work) { return notify_master_->AddWorker(work.worker_spec); }

Status Worker::RemoveWorker(const ServableWorkerContext &work) {
WorkerSpec worker_spec;
worker_spec.servable_name = work.servable_spec.servable_name;
worker_spec.version_number = work.servable_spec.version_number;
for (auto &method : work.servable_signature.methods) {
WorkerMethodInfo worker_method_info;
worker_method_info.name = method.method_name;
for (auto &name : method.inputs) {
worker_method_info.input_names.push_back(name);
}
worker_spec.methods.push_back(worker_method_info);
}
return notify_master_->RemoveWorker(worker_spec);
return notify_master_->RemoveWorker(work.worker_spec);
}

Status Worker::Run(const proto::PredictRequest &request, proto::PredictReply *reply) {
@@ -189,74 +135,8 @@ std::pair<Status, std::shared_ptr<AsyncResult>> Worker::RunAsync(const RequestSp
return {SUCCESS, result};
}

Status Worker::InitEnv(ModelType model_type, const std::map<std::string, std::string> &other_options) {
Status status;
if (session_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Session has been inited";
}
auto context = ServableContext::Instance();
DeviceType device_type = kDeviceTypeNotSpecified;
session_ = InferSessionStorage::Instance().Get(context->GetDeviceType(), model_type, &device_type);
if (session_ == nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Cannot find session registered for device type " << context->GetDeviceType() << " and model type "
<< model_type << ". Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model";
}
if (device_type != kDeviceTypeNotSpecified) {
context->SetDeviceType(device_type);
}
status = session_->InitEnv(context->GetDeviceType(), context->GetDeviceId(), other_options);
if (status != SUCCESS) {
session_ = nullptr;
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Init env failed, device type " << context->GetDeviceType() << ", device id " << context->GetDeviceId();
}
return SUCCESS;
}

Status Worker::FinalizeEnv() {
if (session_ != nullptr) {
return session_->FinalizeEnv();
}
return SUCCESS;
}
Status Worker::LoadModel(LoadServableSpec *servable_spec, uint64_t version_number, ServableWorkerContext *work) {
MSI_EXCEPTION_IF_NULL(servable_spec);
MSI_EXCEPTION_IF_NULL(work);
servable_spec->version_number = version_number;
ServableSignature signature;
if (!ServableStorage::Instance().GetServableDef(servable_spec->servable_name, &signature)) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_spec->servable_name << " has not been registerd";
}
const auto &servable_meta = signature.servable_meta;
std::string model_file_name = servable_spec->servable_directory + "/" + servable_spec->servable_name + "/" +
std::to_string(version_number) + "/" + servable_meta.servable_file;
uint32_t model_id;
auto context = ServableContext::Instance();
Status status = session_->LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name,
servable_meta.model_format, servable_meta.without_batch_dim_inputs,
servable_meta.load_options, &model_id);
if (status != SUCCESS) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Load model failed, servable directory: '" << servable_spec->servable_directory << "', servable name: '"
<< servable_spec->servable_name << "', servable file: '" << servable_meta.servable_file
<< "', version number " << version_number << ", options " << servable_meta.load_options;
}
auto service = std::make_shared<WorkExecutor>(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(),
GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess());
status = service->Init(signature, std::make_shared<AscendModelServable>(session_, model_id));
if (status != SUCCESS) {
return status;
}
work->servable_spec = *servable_spec;
work->servable_signature = signature;
work->worker_service = service;
work->model_id = model_id;
work->model_file_name = model_file_name;
return SUCCESS;
}

void Worker::Update() {
/*
if (version_strategy_ == kVersionStrategySpecific) {
return;
}
@@ -291,10 +171,19 @@ void Worker::Update() {
MSI_LOG_INFO << "UnLoad Model version " << iter->servable_spec.version_number << " success";
work_list_.erase(iter);
}
*/
}

Status Worker::StartServable(const std::string &servable_directory, const std::string &servable_name,
uint32_t version_number, std::shared_ptr<BaseNotifyMaster> notify_master) {
Status Worker::StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip,
int32_t port) {
if (worker_grpc_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running";
}
worker_grpc_server_ = grpc_server;
return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port);
}

Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
if (servable_started_) {
MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently.";
@@ -307,58 +196,42 @@ Status Worker::StartServable(const std::string &servable_directory, const std::s
cpp_postprocess_.Start(2);

notify_master_ = std::move(notify_master);
base_spec_.servable_directory = servable_directory;
base_spec_.servable_name = servable_name;
base_spec_.version_number = version_number;

std::string version_strategy;
if (version_number == 0) {
version_strategy = kVersionStrategyLastest;
} else {
version_strategy = kVersionStrategySpecific;
}
Status status;
auto servable_name = servable->GetServableName();
ServableSignature signature;
if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
}
if (session_ == nullptr) {
status = InitEnv(signature.servable_meta.model_format, {});
if (status != SUCCESS) {
MSI_LOG_ERROR << "Init env failed";
return status;
}
return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_name << " has not been registered";
}
std::vector<uint64_t> real_versions;
status = LoadServableConfig(base_spec_, version_strategy, &real_versions);
auto service = std::make_shared<WorkExecutor>(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(),
GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess());
auto status = service->Init(signature, servable);
if (status != SUCCESS) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Start servable failed, there is no servable of the specified version number, specified version number: "
<< version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '"
<< base_spec_.servable_name
<< "'. version number is a positive integer(started from 1) and 0 represents the maximum version number.";
return status;
}
for (auto real_version_number : real_versions) {
ServableWorkerContext work;
status = LoadModel(&base_spec_, real_version_number, &work);
if (status != SUCCESS) {
return status;
ServableWorkerContext work;
WorkerSpec worker_spec;
worker_spec.servable_name = servable_name;
worker_spec.version_number = servable->GetServableVersion();
for (auto &method : signature.methods) {
WorkerMethodInfo worker_method_info;
worker_method_info.name = method.method_name;
for (auto &name : method.inputs) {
worker_method_info.input_names.push_back(name);
}
work_list_.push_back(work);
worker_spec.methods.push_back(worker_method_info);
}
work.worker_spec = worker_spec;
work.servable_signature = signature;
work.worker_service = service;
work.servable = servable;

work_list_.push_back(work);

status = RegisterWorker();
if (status != SUCCESS) {
MSI_LOG_ERROR << "Register worker failed";
return status;
}
servable_started_ = true;
status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory
<< "', servable name: '" << servable_name
<< "', specified version number: " << version_number
<< ", started version numbers: " << real_versions;
MSI_LOG_INFO << status.StatusMessage();
std::cout << status.StatusMessage() << std::endl;
return SUCCESS;
}

@@ -368,119 +241,39 @@ void Worker::StopServable(bool notify_master) {
}

void Worker::Clear() {
std::unique_lock<std::shared_mutex> lock(worker_shared_lock_);
ServableStorage::Instance().Clear();
worker_grpc_server_ = nullptr;
if (clear_flag_.test_and_set()) {
return;
}
std::unique_lock<std::shared_mutex> lock(worker_shared_lock_);
MSI_LOG_INFO << "Start clear worker session";
version_controller_.StopPollModelPeriodic();
if (exit_notify_master_ && servable_started_) {
notify_master_->Unregister();
}
if (session_ != nullptr) {
for (auto &it : work_list_) {
session_->UnloadModel(it.model_id);
}
for (auto &worker_item : work_list_) {
worker_item.servable->Clear();
}
work_list_.clear();
FinalizeEnv();

session_ = nullptr;
py_task_queue_group_.Stop();
cpp_preprocess_.Stop();
cpp_postprocess_.Stop();
ServableStorage::Instance().Clear();
grpc_async_worker_server_ = nullptr;
servable_started_ = false;
MSI_LOG_INFO << "End clear worker session";
}

bool Worker::HasCleared() { return !servable_started_; }
bool Worker::IsRunning() { return servable_started_; }

Worker::~Worker() { Clear(); }

void Worker::GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions) {
MSI_EXCEPTION_IF_NULL(real_versions);
// define version_strategy:"specific","lastest","multi"
if (version_strategy_ == kVersionStrategySpecific) {
real_versions->push_back(servable_spec.version_number);
return;
}
auto trans_to_integer = [](const std::string &str) -> uint32_t {
uint32_t parsed_value = 0;
for (auto c : str) {
if (c < '0' || c > '9') {
return 0;
}
parsed_value = parsed_value * 10 + c - '0';
}
if (std::to_string(parsed_value) != str) {
return 0;
}
return parsed_value;
};
uint64_t newest_version = 0;
std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name;
auto sub_dir = GetAllSubDirsNotFullPath(model_path);
static std::set<std::string> ignore_dir;
for (const auto &dir : sub_dir) {
if (dir == "__pycache__") continue;
auto version_parse = trans_to_integer(dir);
if (version_parse == 0) {
if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) {
MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory
<< ", model_name " << servable_spec.servable_name;
}
continue;
}
real_versions->push_back(version_parse);
if (version_parse > newest_version) {
newest_version = version_parse;
}
}
if (version_strategy_ == kVersionStrategyLastest) {
real_versions->clear();
if (newest_version != 0) {
real_versions->push_back(newest_version);
}
}
}
Status Worker::LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy,
std::vector<uint64_t> *real_versions) {
MSI_EXCEPTION_IF_NULL(real_versions);
auto model_directory = servable_spec.servable_directory;
auto model_name = servable_spec.servable_name;

if (!DirOrFileExist(model_directory + "/" + model_name)) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Model not found, model_directory " << model_directory << ", model_name " << model_name;
}
std::string model_path = model_directory + "/" + model_name;
auto version_directory = [model_path](int64_t version_number) {
return model_path + "/" + std::to_string(version_number);
};
version_strategy_ = version_strategy;
// version_strategy:"specific","lastest","multi"
GetVersions(servable_spec, real_versions);
if (real_versions->size() == 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name;
}
for (auto real_version_number : *real_versions) {
if (!DirOrFileExist(version_directory(real_version_number))) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory "
<< model_directory << ", model_name " << model_name;
}
}
return SUCCESS;
}

ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) {
ServableWorkerContext context;
if (request_spec.version_number != 0) {
auto item = find_if(work_list_.begin(), work_list_.end(), [&](const ServableWorkerContext &v) {
return v.servable_spec.servable_name == request_spec.servable_name &&
v.servable_spec.version_number == request_spec.version_number;
return v.worker_spec.servable_name == request_spec.servable_name &&
v.worker_spec.version_number == request_spec.version_number;
});
if (item != work_list_.end()) {
context = *item;
@@ -488,10 +281,10 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec)
} else {
uint64_t max_version = 0;
for (auto &item : work_list_) {
if (item.servable_spec.servable_name == request_spec.servable_name &&
item.servable_spec.version_number > max_version) {
if (item.worker_spec.servable_name == request_spec.servable_name &&
item.worker_spec.version_number > max_version) {
context = item;
max_version = item.servable_spec.version_number;
max_version = item.worker_spec.version_number;
}
}
}
@@ -500,11 +293,11 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec)

Worker::Worker() {}

ssize_t Worker::GetBatchSize() const {
ssize_t batch_size_ret = -1;
for (auto service : work_list_) {
auto batch_size = session_->GetBatchSize(service.model_id);
if (batch_size != -1) {
size_t Worker::GetBatchSize() const {
size_t batch_size_ret = 1;
for (const auto &service : work_list_) {
auto batch_size = service.servable->GetBatchSize();
if (batch_size != 0) {
batch_size_ret = batch_size;
break;
}
@@ -532,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) {
const int kWaitMaxHundredMs = 100;
int i;
for (i = 0; i < kWaitMaxHundredMs; i++) { //
if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) {
if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) {
instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped");
return SYSTEM_ERROR;
}


+ 10
- 19
mindspore_serving/ccsrc/worker/worker.h View File

@@ -32,6 +32,8 @@
#include "worker/task_queue.h"
#include "worker/version_control/version_controller.h"
#include "common/grpc_async_server.h"
#include "worker/sevable_base.h"
#include "worker/grpc/worker_server.h"

namespace mindspore {
namespace serving {
@@ -53,11 +55,10 @@ class AsyncResult {
};

struct ServableWorkerContext {
LoadServableSpec servable_spec;
WorkerSpec worker_spec;
ServableSignature servable_signature;
std::shared_ptr<WorkExecutor> worker_service = nullptr;
uint32_t model_id = 0;
std::string model_file_name;
std::shared_ptr<ServableBase> servable = nullptr;
};

class MS_API Worker {
@@ -72,17 +73,14 @@ class MS_API Worker {
Status Run(const RequestSpec &request_spec, const std::vector<InstanceData> &inputs, std::vector<Instance> *outputs);
std::pair<Status, std::shared_ptr<AsyncResult>> RunAsync(const RequestSpec &request_spec,
const std::vector<InstanceData> &inputs);
Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master);

Status InitEnv(ModelType model_type, const std::map<std::string, std::string> &other_options);
Status FinalizeEnv();
Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip,
int32_t port);

Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number,
std::shared_ptr<BaseNotifyMaster> notify_master);
void StopServable(bool notify_master = true);
bool HasCleared();
bool IsRunning();
Status RegisterWorker();
Status StartGrpcServer(const std::string &ip, uint32_t grpc_port);
Status LoadModel(LoadServableSpec *servable_spec, uint64_t version, ServableWorkerContext *work);
void Update();
Status StartVersionController();
Status AddWorker(const ServableWorkerContext &work);
@@ -93,31 +91,24 @@ class MS_API Worker {
std::shared_ptr<TaskQueue> GetPyTaskQueuePostprocess() { return py_task_queue_group_.GetPostprocessTaskQueue(); }
std::shared_ptr<TaskQueue> GetCppTaskQueuePreprocess() { return cpp_preprocess_.GetTaskQueue(); }
std::shared_ptr<TaskQueue> GetCppTaskQueuePostprocess() { return cpp_postprocess_.GetTaskQueue(); }
ssize_t GetBatchSize() const;
size_t GetBatchSize() const;

private:
static std::shared_ptr<Worker> global_worker_;

std::vector<ServableWorkerContext> work_list_;
std::shared_ptr<serving::InferSession> session_ = nullptr;
std::string version_strategy_;
PyTaskQueueGroup py_task_queue_group_;
PreprocessThreadPool cpp_preprocess_;
PostprocessThreadPool cpp_postprocess_;

VersionController version_controller_;
LoadServableSpec base_spec_;
std::atomic_bool exit_notify_master_ = true;
std::atomic_bool servable_started_ = false;
std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT;
std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr;
std::shared_ptr<MSWorkerServer> worker_grpc_server_ = nullptr;

std::shared_mutex worker_shared_lock_;

ServableWorkerContext GetServableWorker(const RequestSpec &request_spec);
Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy,
std::vector<uint64_t> *real_version_number);
void GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions);
};

} // namespace serving


+ 2
- 0
mindspore_serving/master/_master.py View File

@@ -18,6 +18,7 @@ import threading
from functools import wraps
from mindspore_serving.worker import check_type
from mindspore_serving import log as logger
from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import Master_

_wait_and_clear_thread = None
@@ -59,6 +60,7 @@ def stop_on_except(func):
@wraps(func)
def handle_except(*args, **kwargs):
try:
ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message
func(*args, **kwargs)
except:
stop()


+ 43
- 0
mindspore_serving/proto/ms_agent.proto View File

@@ -0,0 +1,43 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// ms_manager.proto
syntax = "proto3";

package mindspore.serving.proto;
import "mindspore_serving/proto/ms_service.proto";

message DistributedPredictRequest {
repeated Tensor inputs = 1;
}

message DistributedPredictReply {
repeated Tensor outputs = 1;
ErrorMsg error_msg = 2;
}

message DistributedExitRequest {
string address = 1;
}

message DistributedExitReply {
ErrorMsg error_msg = 1;
}

service MSAgent {
rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {}
rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {}
}

+ 53
- 0
mindspore_serving/proto/ms_distributed.proto View File

@@ -0,0 +1,53 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// ms_manager.proto
syntax = "proto3";

package mindspore.serving.proto;
import "mindspore_serving/proto/ms_service.proto";

message AgentSpec {
int64 rank_id = 1;
int64 batch_size = 2;
repeated Tensor inputs = 3;
repeated Tensor outputs = 4;
}

message AgentRegisterRequest {
repeated AgentSpec agent_spec = 1;
string address = 2;
}

message AgentRegisterReply {
ErrorMsg error_msg = 1;
}

message AgentExitRequest {
repeated AgentSpec agent_spec = 1;
string address = 2;
}

message AgentExitReply {
ErrorMsg error_msg = 1;
}

message AgentFailedRequest {
}

message AgentFailedReply {
ErrorMsg error_msg = 1;
}

+ 2
- 0
mindspore_serving/proto/ms_service.proto View File

@@ -80,6 +80,8 @@ message Tensor {

// for string type and images, the dtype is MS_BYTES.
repeated bytes bytes_val = 4;

int64 size = 5;
}

message ServableSpec {


+ 6
- 0
mindspore_serving/proto/ms_worker.proto View File

@@ -20,8 +20,14 @@ syntax = "proto3";
package mindspore.serving.proto;
import "mindspore_serving/proto/ms_service.proto";
import "mindspore_serving/proto/ms_master.proto";
import "mindspore_serving/proto/ms_distributed.proto";

service MSWorker {
// for master
rpc Predict(PredictRequest) returns (PredictReply) {}
rpc Exit(ExitRequest) returns (ExitReply) {}
// for worker agent
rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {}
rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {}
rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {}
}

+ 3
- 1
mindspore_serving/worker/_worker.py View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inferface for start up servable"""
"""Interface for start up servable"""

import threading
from functools import wraps
from mindspore_serving import log as logger
from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import Worker_
from .register.preprocess import preprocess_storage
from .register.postprocess import postprocess_storage
@@ -77,6 +78,7 @@ def stop_on_except(func):
@wraps(func)
def handle_except(*args, **kwargs):
try:
ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message
func(*args, **kwargs)
except:
stop()


+ 250
- 0
mindspore_serving/worker/distributed/agent_startup.py View File

@@ -0,0 +1,250 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving, distributed worker agent startup"""

import os
import time
from multiprocessing import Process, Pipe

from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_

from mindspore_serving import log as logger
from mindspore_serving.worker import check_type
from mindspore_serving.worker.distributed import worker_agent


def _get_local_ip(rank_list, port):
"""Get the local ip from the rank table config"""
import socket
ip_list = []
for item in rank_list:
ip_list.append(item.ip)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
for ip in ip_list:
try:
s.bind((ip, port))
logger.info(f"Get local machine ip success, ip {ip}")
return ip
# pylint: disable=bare-except
except:
pass
raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}")


def _update_model_files_path(model_files, group_config_files):
"""Check and return model files or group config files"""
script_dir = os.path.dirname(os.path.realpath(__file__))
logger.info(f"input model files: {model_files}")
logger.info(f"input group config files: {group_config_files}")
model_files_temp = []
for item in model_files:
file_name = os.path.join(script_dir, item)
if not os.access(file_name, os.R_OK):
raise RuntimeError(f"Cannot access model file '{file_name}'")
model_files_temp.append(file_name)

group_files_temp = []
for item in group_config_files:
file_name = os.path.join(script_dir, item)
if not os.access(file_name, os.R_OK):
raise RuntimeError(f"Cannot access group config file '{file_name}'")
group_files_temp.append(file_name)

logger.info(f"absolute model files: {model_files_temp}")
logger.info(f"absolute group config files: {group_files_temp}")
return model_files_temp, group_files_temp


def _make_json_table_file(distributed_config):
"""Make rank table json file"""
rank_size = len(distributed_config.rank_list)
runtime_dir = os.path.abspath(".")
time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time())))
rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json")
with open(rank_table_file_name, "w") as fp:
fp.write(distributed_config.rank_table_content)
return rank_table_file_name


signal_success = "Success"
signal_exit = "Exit"
signal_heartbeat = "HeartBeat"


def _recv_parent(index, recv_pipe):
"""Receive message from Start up process.
Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit.
Return True on receiving signal_success."""
try:
while True:
heartbeat_count = 0
while not recv_pipe.poll(0.1):
if ExitSignalHandle_.has_stopped():
logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker")
return False
heartbeat_count += 1
if heartbeat_count >= 30: # 3s
logger.warning(f"Child {index}: Exit on failure of receiving parent message")
return False
parent_signal = recv_pipe.recv()
if parent_signal != signal_heartbeat:
break
if parent_signal == signal_success:
logger.info(f"Child {index}: Receive success")
return True
if parent_signal == signal_exit:
logger.warning(f"Child {index}: Exit on receiving exit message")
else:
logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}")
# pylint: disable=broad-except
except Exception as e:
logger.warning(f"Child {index}: Exit on exception: {e}")
return False


def _agent_process(send_pipe, recv_pipe, index, start_config):
"""Agent process"""
try:
# listening success or failed message from parent process
ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message
worker_agent.start_worker_agent(start_config=start_config)
send_pipe.send((index, signal_success))
success_msg = _recv_parent(index, recv_pipe)
if not success_msg:
worker_agent.stop()
send_pipe.close()
recv_pipe.close()
# pylint: disable=broad-except
except Exception as e:
logger.error(f"Child {index}: Catch exception and notify exit of others")
send_pipe.send((index, e))
worker_agent.stop()
raise


def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list):
"""Listening child process"""
def send_pipe_msg(send_pipe, msg):
try:
send_pipe.send(msg)
# pylint: disable=broad-except
except Exception as e:
logger.warning(f"Send pipe message exception happen: {e}")

count = len(send_pipe_list)
for _ in range(count):
while True:
if p_recv_pipe.poll(0.1):
break
for send_pipe, process in zip(send_pipe_list, subprocess_list):
if process.is_alive():
continue
logger.warning("Fail to start agents because of death of one agent")
for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list):
if process_x.is_alive():
send_pipe_msg(send_pipe_x, signal_exit)
return False
for send_pipe in send_pipe_list:
send_pipe_msg(send_pipe, signal_heartbeat)

_, msg = p_recv_pipe.recv()
if isinstance(msg, Exception):
logger.warning("Fail to start agents because of exception raise by one agent")
for send_pipe in send_pipe_list:
send_pipe_msg(send_pipe, signal_exit)
return False

for send_pipe in send_pipe_list:
send_pipe_msg(send_pipe, signal_success)
logger.info("Success to start agents")
return True


def _startup_all_agents(common_meta, worker_ip, worker_port,
agent_ip, agent_start_port, device_id_list, rank_id_list,
model_files, group_config_files, rank_table_file):
"""Start up all agents in one machine"""
servable_name = common_meta.servable_name
index = 0
send_pipe_list = []
subprocess_list = []
c_send_pipe, p_recv_pipe = Pipe()
for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files,
group_config_files):
p_send_pipe, c_recv_pipe = Pipe()
send_pipe_list.append(p_send_pipe)

agent_port = agent_start_port + index

start_config = AgentStartUpConfig_()
start_config.rank_id = rank_id
start_config.device_id = device_id
start_config.model_file_name = model_file
start_config.group_file_name = group_file
start_config.rank_table_json_file_name = rank_table_file
start_config.agent_ip = agent_ip
start_config.agent_port = agent_port
start_config.worker_ip = worker_ip
start_config.worker_port = worker_port
start_config.common_meta = common_meta

process = Process(target=_agent_process,
args=(c_send_pipe, c_recv_pipe, index, start_config),
name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}")
process.start()
subprocess_list.append(process)
index += 1
ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list)
if not ret:
WorkerAgent_.notify_failed(worker_ip, worker_port)


def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000):
"""Start up all needed worker agents on one machine"""
check_type.check_str("worker_ip", worker_ip)
check_type.check_ip_port("worker_port", worker_port)
check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7)
model_files = check_type.check_and_as_int_tuple_list("model_files", model_files)
group_config_files = check_type.check_and_as_int_tuple_list("group_config_files", group_config_files)
distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)

# get machine ip
rank_list = distributed_config.rank_list
local_ip = _get_local_ip(rank_list, agent_start_port)
# get all device_id and rank_id
local_device_id_list = []
local_rank_id_list = []
for rank_id, item in enumerate(rank_list):
if item.ip == local_ip:
local_device_id_list.append(item.device_id)
local_rank_id_list.append(rank_id)

# handle model files and group config files
if len(local_device_id_list) != len(model_files):
raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to model files size "
f"{len(model_files)}, model files: {model_files}")

if len(local_device_id_list) != len(group_config_files):
raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to group config "
f"files size {len(group_config_files)}, group config files: {group_config_files}")

model_files, group_config_files = _update_model_files_path(model_files, group_config_files)

# make json table file and export env
rank_table_file = _make_json_table_file(distributed_config)
_startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port,
local_device_id_list, local_rank_id_list,
model_files, group_config_files, rank_table_file)

+ 131
- 0
mindspore_serving/worker/distributed/distributed_worker.py View File

@@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving, distributed worker startup"""
from mindspore_serving._mindspore_serving import Worker_

from mindspore_serving.worker import check_type
from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear
from mindspore_serving.worker._worker import stop_on_except, _load_servable_config


@stop_on_except
def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1,
worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100,
wait_agents_time_in_seconds=300):
r"""
Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master
through gRPC (master_ip, master_port).

Serving has two running modes. One is running in a single process, providing the Serving service of a single model.
The other includes a master and multiple workers. This interface is for the second scenario.

The master is responsible for providing the Serving access interface for clients,
while the worker is responsible for providing the inference service of the specific model. The communications
between the master and workers through gPRC are defined as (master_ip, master_port) and (worker_ip, worker_port).

Args:
servable_directory (str): The directory where the servable is located in. There expects to has a directory
named `servable_name`. For more detail:
`How to config Servable <https://www.mindspore.cn/tutorial/inference/zh-CN/master/serving_model.html>`_ .

servable_name (str): The servable name.
version_number (int): Servable version number to be loaded. The version number should be a positive integer,
starting from 1, and 0 means to load the latest version. Default: 0.
rank_table_json_file (str): The ranke table json file name.
master_ip (str): The master ip the worker linked to.
master_port (int): The master port the worker linked to.
worker_ip (str): The worker ip the master and agents linked to.
worker_port (int): The worker port the master and agents linked to.
wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents.

Examples:
>>> import os
>>> from mindspore_serving import worker
>>>
>>> servable_dir = os.path.abspath(".")
>>> worker.start_servable(servable_dir, "lenet", device_id=0, \
... master_ip="127.0.0.1", master_port=6500, \
... host_ip="127.0.0.1", host_port=6600)
"""
check_type.check_str('servable_directory', servable_directory)
check_type.check_str('servable_name', servable_name)
check_type.check_int('version_number', version_number, 0)
if version_number == 0:
version_number = 1
check_type.check_str('rank_table_json_file', rank_table_json_file)

check_type.check_str('master_ip', master_ip)
check_type.check_ip_port('master_port', master_port)

check_type.check_str('worker_ip', worker_ip)
check_type.check_ip_port('worker_port', worker_port)

_load_servable_config(servable_directory, servable_name)
Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number,
master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds)
_start_py_task(Worker_.get_batch_size())
_start_wait_and_clear()


@stop_on_except
def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1,
worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300):
r"""
Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in
the process of the master.

Serving has two running modes. One is running in a single process, providing the Serving service of a single model.
The other includes a master and multiple workers. This interface is for the first scenario.

Args:
servable_directory (str): The directory where the servable is located in. There expects to has a directory named
`servable_name`. For more detail:
`How to config Servable <https://www.mindspore.cn/tutorial/inference/zh-CN/master/serving_model.html>`_ .

servable_name (str): The servable name.
version_number (int): Servable version number to be loaded. The version number should be a positive integer,
starting from 1, and 0 means to load the latest version. Default: 0.
rank_table_json_file (str): The ranke table json file name.
worker_ip (str): The worker ip the agents linked to.
worker_port (int): The worker port the agents linked to.
wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents.

Examples:
>>> import os
>>> from mindspore_serving import worker
>>> from mindspore_serving import master
>>>
>>> servable_dir = os.path.abspath(".")
>>> worker.start_servable_in_master(servable_dir, "lenet", device_id=0)
>>>
>>> master.start_grpc_server("0.0.0.0", 5500)
>>> master.start_restful_server("0.0.0.0", 1500)
"""
check_type.check_str('servable_directory', servable_directory)
check_type.check_str('servable_name', servable_name)
check_type.check_int('version_number', version_number, 0)
if version_number == 0:
version_number = 1

check_type.check_str('rank_table_json_file', rank_table_json_file)

check_type.check_str('worker_ip', worker_ip)
check_type.check_ip_port('worker_port', worker_port)

_load_servable_config(servable_directory, servable_name)
Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file,
version_number, worker_ip, worker_port, wait_agents_time_in_seconds)
_start_py_task(Worker_.get_batch_size())
_start_wait_and_clear()

+ 43
- 0
mindspore_serving/worker/distributed/register.py View File

@@ -0,0 +1,43 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving, distributed worker register"""

from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_
from mindspore_serving.worker import check_type
from mindspore_serving.worker.common import get_servable_dir
from mindspore_serving import log as logger


def declare_distributed_servable(rank_size, stage_size, with_batch_dim, without_batch_dim_inputs):
"""declare distributed servable in servable_config.py"""
check_type.check_bool('with_batch_dim', with_batch_dim)

meta = ServableMeta_()
meta.common_meta.servable_name = get_servable_dir()
meta.common_meta.with_batch_dim = with_batch_dim
if without_batch_dim_inputs:
without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs',
without_batch_dim_inputs, 0)
meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs

# init distributed servable meta info
check_type.check_int("rank_size", rank_size, 1)
check_type.check_int("stage_size", stage_size, 1)
meta.distributed_meta.rank_size = rank_size
meta.distributed_meta.stage_size = stage_size
ServableStorage_.declare_distributed_servable(meta)
logger.info(f"Declare distributed servable, servable_name: {meta.common_meta.servable_name} "
f", rank_size: {rank_size} , stage_size: {stage_size}, with_batch_dim: {with_batch_dim} "
f", without_batch_dim_inputs: {without_batch_dim_inputs}")

+ 66
- 0
mindspore_serving/worker/distributed/worker_agent.py View File

@@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving, distributed worker agent"""

import os
import threading
from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_
from mindspore_serving import log as logger


def start_worker_agent(start_config):
"""Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents
"""
if not isinstance(start_config, AgentStartUpConfig_):
raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_")

os.environ["RANK_ID"] = str(start_config.rank_id)
os.environ["DEVICE_ID"] = str(start_config.device_id)
os.environ["MS_ENABLE_HCCL"] = "1"
os.environ["PARA_GROUP_FILE"] = start_config.group_file_name
os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name

for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE",
"LD_LIBRARY_PATH", "PYTHONPATH"):
logger.info(f"Env {item}: {os.getenv(item, '')}")
WorkerAgent_.start_agent(start_config)

start_wait_and_clear()


_wait_and_clear_thread = None


def start_wait_and_clear():
"""Waiting for Ctrl+C, and clear up environment"""

def thread_func():
logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------")
print("Serving worker: wait for Ctrl+C to exit ------------------------------------")
WorkerAgent_.wait_and_clear()
logger.info("Serving worker: exited ------------------------------------")
print("Serving worker: exited ------------------------------------")

global _wait_and_clear_thread
if not _wait_and_clear_thread:
_wait_and_clear_thread = threading.Thread(target=thread_func)
_wait_and_clear_thread.start()


def stop():
r"""
Stop the running of agent.
"""
WorkerAgent_.stop_and_clear()

+ 2
- 24
mindspore_serving/worker/register/method.py View File

@@ -35,28 +35,6 @@ method_tag_predict = PredictPhaseTag_.kPredictPhaseTag_Predict
method_tag_postprocess = PredictPhaseTag_.kPredictPhaseTag_Postprocess


class _ServableStorage:
"""Declare servable info"""

def __init__(self):
pass

@staticmethod
def declare_servable(servable_meta):
"""Declare servable info excluding method, input and output count"""
ServableStorage_.declare_servable(servable_meta)

@staticmethod
def declare_servable_input_output(servable_name, inputs_count, outputs_count):
"""Declare input and output count of servable"""
ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count)

@staticmethod
def register_method(method_signature):
"""Declare method of servable"""
ServableStorage_.register_method(method_signature)


class _TensorDef:
"""Data flow item, for definitions of data flow in a method"""

@@ -251,7 +229,7 @@ def call_servable(*args):

servable_name = get_servable_dir()
inputs_count, outputs_count = method_def_ast_meta_[_call_servable_name]
_ServableStorage.declare_servable_input_output(servable_name, inputs_count, outputs_count)
ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count)
if inputs_count != len(args):
raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', given servable input "
f"size {len(args)} not match '{servable_name}' ast parse size {inputs_count}")
@@ -467,7 +445,7 @@ def register_method(output_names):
f", servable_name {method_def_context_.servable_name}, inputs: {input_names}, outputs: "
f"{output_names}")

_ServableStorage.register_method(method_def_context_)
ServableStorage_.register_method(method_def_context_)
return func

return register

+ 17
- 16
mindspore_serving/worker/register/servable.py View File

@@ -14,11 +14,10 @@
# ============================================================================
"""Servable declaration interface"""

from mindspore_serving._mindspore_serving import ServableMeta_
from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_
from mindspore_serving.worker import check_type
from mindspore_serving.worker.common import get_servable_dir
from mindspore_serving import log as logger
from .method import _ServableStorage


def declare_servable(servable_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None):
@@ -37,19 +36,25 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N
RuntimeError: The type or value of the parameters is invalid.
"""

check_type.check_str('servable_file', servable_file)
check_type.check_str('model_format', model_format)
check_type.check_bool('with_batch_dim', with_batch_dim)

meta = ServableMeta_()
meta.common_meta.servable_name = get_servable_dir()
meta.common_meta.with_batch_dim = with_batch_dim
if without_batch_dim_inputs:
without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs',
without_batch_dim_inputs, 0)
meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs

# init local servable meta info
check_type.check_str('servable_file', servable_file)
check_type.check_str('model_format', model_format)
model_format = model_format.lower()
if model_format not in ("om", "mindir"):
raise RuntimeError("model format can only be OM or MindIR")

meta = ServableMeta_()
meta.servable_name = get_servable_dir()
meta.servable_file = servable_file
meta.set_model_format(model_format)
meta.with_batch_dim = with_batch_dim
meta.local_meta.servable_file = servable_file
meta.local_meta.set_model_format(model_format)
if isinstance(options, dict):
for k, w in options.items():
check_type.check_str("options key", k)
@@ -61,14 +66,10 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N
raise RuntimeError(f"Parameter 'options' should be None, dict of <str,str> or AclOptions, but "
f"gotten {type(options)}")
if options:
meta.options = options
if without_batch_dim_inputs:
without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs',
without_batch_dim_inputs, 0)
meta.without_batch_dim_inputs = without_batch_dim_inputs
meta.local_meta.options = options

_ServableStorage.declare_servable(meta)
logger.info(f"Declare servable, servable_name: {meta.servable_name} "
ServableStorage_.declare_servable(meta)
logger.info(f"Declare servable, servable_name: {meta.common_meta.servable_name} "
f", servable_file: {servable_file} , model_format: {model_format}, with_batch_dim: {with_batch_dim} "
f", options: {options}, without_batch_dim_inputs: {without_batch_dim_inputs}")



+ 11
- 5
tests/ut/cpp/common/test_servable_common.h View File

@@ -27,6 +27,7 @@
#include "worker/worker.h"
#include "worker/notfiy_master/local_notify.h"
#include "worker/context.h"
#include "worker/local_servable/local_sevable.h"
#include "master/grpc/grpc_process.h"
#include "mindspore_serving/proto/ms_service.pb.h"

@@ -102,16 +103,21 @@ class TestMasterWorker : public UT::Common {
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable(servable_dir, servable_name, version_number, notify_master);
auto servable = std::make_shared<LocalModelServable>();
auto status = servable->StartServable(servable_dir, servable_name, version_number);
if (status != SUCCESS) {
return status;
}
status = Worker::GetInstance().StartServable(servable, notify_master);
return status;
}
static void DeclareServable(const std::string &servable_name, const std::string &servable_file,
const std::string &model_type, bool with_batch_dim = false) {
ServableMeta servable_meta;
servable_meta.servable_name = servable_name;
servable_meta.servable_file = servable_file;
servable_meta.SetModelFormat(model_type);
servable_meta.with_batch_dim = with_batch_dim;
servable_meta.common_meta.servable_name = servable_name;
servable_meta.common_meta.with_batch_dim = with_batch_dim;
servable_meta.local_meta.servable_file = servable_file;
servable_meta.local_meta.SetModelFormat(model_type);
// declare_servable
ServableStorage::Instance().DeclareServable(servable_meta);
}


+ 17
- 63
tests/ut/cpp/tests/test_start_worker.cc View File

@@ -30,10 +30,7 @@ TEST_F(TestStartWorker, test_worker_start_success) {
DeclareServable("test_servable", "test_add.mindir", "mindir", true);
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);
// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_TRUE(status.IsSuccess());
}

@@ -43,10 +40,7 @@ TEST_F(TestStartWorker, test_worker_start_error_model_file_name) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
auto status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), "Load model failed, servable directory: ");
}
@@ -57,12 +51,8 @@ TEST_F(TestStartWorker, test_worker_start_error_version_number) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
int error_version_number = 2;
Status status =
Worker::GetInstance().StartServable("test_servable_dir", "test_servable", error_version_number, notify_master);
auto status = StartServable("test_servable_dir", "test_servable", error_version_number);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(),
"Start servable failed, there is no servable of"
@@ -78,11 +68,8 @@ TEST_F(TestStartWorker, test_worker_start_multi_version_number) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
int version_number = 0;
Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", version_number, notify_master);
Status status = StartServable(servable_dir, "test_servable", version_number);
EXPECT_TRUE(status.IsSuccess());
}

@@ -96,10 +83,7 @@ TEST_F(TestStartWorker, test_worker_start_version_number_no_valid) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", 0, notify_master);
Status status = StartServable(servable_dir, "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(),
"Start servable failed, there is no servable of"
@@ -112,11 +96,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_dir) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
std::string error_servable_dir = "test_servable_dir_error";
Status status = Worker::GetInstance().StartServable(error_servable_dir, "test_servable", 0, notify_master);
Status status = StartServable(error_servable_dir, "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(),
"Start servable failed, there is no servable of"
@@ -129,11 +110,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_name) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
std::string error_servable_name = "test_servable_error";
Status status = Worker::GetInstance().StartServable("test_servable_dir", error_servable_name, 0, notify_master);
Status status = StartServable("test_servable_dir", error_servable_name, 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), "'test_servable_error' has not been registered");
}
@@ -144,24 +122,18 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_format) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), "Cannot find session registered for device type Ascend and model type OM");
ExpectContainMsg(status.StatusMessage(), "Not support device type Ascend and model type OM. ");
}

TEST_F(TestStartWorker, test_worker_start_no_registered_method) {
Init("test_servable_dir", "test_servable", 1, "test_add.mindir");
Init("test_servable_dir", "test_servable", 2, "test_add.mindir");
DeclareServable("test_servable", "test_add.mindir", "mindir", true);
// no registered method
// RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);
// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 2);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), "There is no method registered for servable");
}
@@ -181,10 +153,7 @@ TEST_F(TestStartWorker, test_worker_start_multi_method) {
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1);
RegisterMethod("test_servable", "add_common2", {"x1", "x2"}, {"y"}, 2, 1);
// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_TRUE(status.IsSuccess());
}

@@ -194,10 +163,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_input_count_not_match)
size_t servable_input_count = 1;
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, servable_input_count, 1);
// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(),
"The inputs count 1 registered in method not equal to "
@@ -210,10 +176,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_output_count_not_match
size_t servable_output_count = 2;
RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, servable_output_count);
// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(),
"The outputs count 2 registered in method not equal to "
@@ -241,10 +204,7 @@ TEST_F(TestStartWorker, test_worker_start_preprocess_not_found) {
ServableStorage::Instance().RegisterMethod(method_signature);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), " preprocess preprocess_fake_fun not defined")
}
@@ -269,10 +229,7 @@ TEST_F(TestStartWorker, test_worker_start_postprocess_not_found) {
ServableStorage::Instance().RegisterMethod(method_signature);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_FALSE(status.IsSuccess());
ExpectContainMsg(status.StatusMessage(), " postprocess postprocess_fake_fun not defined")
}
@@ -300,10 +257,7 @@ TEST_F(TestStartWorker, test_worker_start_with_preproces_and_postprocess_success
ServableStorage::Instance().RegisterMethod(method_signature);

// start_servable
auto notify_master = std::make_shared<LocalNotifyMaster>();
ServableContext::Instance()->SetDeviceId(0);
ServableContext::Instance()->SetDeviceTypeStr("Ascend");
Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master);
Status status = StartServable("test_servable_dir", "test_servable", 0);
EXPECT_TRUE(status.IsSuccess());
}



+ 16
- 13
tests/ut/runtest.sh View File

@@ -16,21 +16,24 @@

set -e

CURRPATH=$(cd "$(dirname $0)" || exit; pwd)
CURRPATH=$(
cd "$(dirname $0)" || exit
pwd
)

if [ $# -gt 0 ]; then
if [ $1 == "python" ]; then
echo "run python ut"
bash ${CURRPATH}/python/runtest.sh $2
elif [ $1 == "cpp" ]; then
echo "run cpp ut"
bash ${CURRPATH}/cpp/runtest.sh
fi
else
echo "run all ut"
# 1.run python testcases
if [ $1 == "python" ]; then
echo "run python ut"
bash ${CURRPATH}/python/runtest.sh $2
# 2.run c++ ut testcases
elif [ $1 == "cpp" ]; then
echo "run cpp ut"
bash ${CURRPATH}/cpp/runtest.sh
fi
else
echo "run all ut"
# 1.run python testcases
bash ${CURRPATH}/python/runtest.sh $2

# 2.run c++ ut testcases
bash ${CURRPATH}/cpp/runtest.sh
fi

Loading…
Cancel
Save