diff --git a/mindspore_serving/ccsrc/common/exit_handle.cc b/mindspore_serving/ccsrc/common/exit_handle.cc index 3b97c21..88d9644 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.cc +++ b/mindspore_serving/ccsrc/common/exit_handle.cc @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { exit_future.wait(); } +// waiting ctrl+c or stop message to exit, +// if no server is running or server has exited, there is no need to wait +void ExitSignalHandle::AgentWait() { + if (!is_running_) { + MSI_LOG_INFO << "Exit Handle has not started or has exited"; + return; + } + auto exit_future = agent_exit_requested_.get_future(); + exit_future.wait(); +} + void ExitSignalHandle::Start() { if (is_running_) { return; @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { is_running_ = true; master_exit_requested_ = std::promise(); worker_exit_requested_ = std::promise(); + agent_exit_requested_ = std::promise(); has_exited_.clear(); InitSignalHandle(); } @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { if (!has_exited_.test_and_set()) { master_exit_requested_.set_value(); worker_exit_requested_.set_value(); + agent_exit_requested_.set_value(); is_running_ = false; } } diff --git a/mindspore_serving/ccsrc/common/exit_handle.h b/mindspore_serving/ccsrc/common/exit_handle.h index 66a2fd2..42654c6 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.h +++ b/mindspore_serving/ccsrc/common/exit_handle.h @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { void InitSignalHandle(); void MasterWait(); void WorkerWait(); + void AgentWait(); void Start(); void Stop(); bool HasStopped(); @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { private: std::promise master_exit_requested_; std::promise worker_exit_requested_; + std::promise agent_exit_requested_; std::atomic_flag has_exited_ = true; std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; std::atomic_bool is_running_ = false; diff --git a/mindspore_serving/ccsrc/common/grpc_client.cc 
b/mindspore_serving/ccsrc/common/grpc_client.cc new file mode 100644 index 0000000..d4ccb8c --- /dev/null +++ b/mindspore_serving/ccsrc/common/grpc_client.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/grpc_client.h" + +namespace mindspore { +namespace serving { +std::unique_ptr client_; +std::unique_ptr distributed_client_; + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/common/grpc_client.h b/mindspore_serving/ccsrc/common/grpc_client.h new file mode 100644 index 0000000..c784713 --- /dev/null +++ b/mindspore_serving/ccsrc/common/grpc_client.h @@ -0,0 +1,115 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H +#define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_service.pb.h" +#include "proto/ms_service.grpc.pb.h" +#include "proto/ms_master.pb.h" +#include "proto/ms_master.grpc.pb.h" +#include "proto/ms_worker.grpc.pb.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" + +namespace mindspore { +namespace serving { + +using PredictOnFinish = std::function; + +using DispatchCallback = std::function; + +template +class MSServiceClient { + public: + MSServiceClient() = default; + ~MSServiceClient() { + if (in_running_) { + cq_.Shutdown(); + if (client_thread_.joinable()) { + try { + client_thread_.join(); + } catch (const std::system_error &) { + } catch (...) { + } + } + } + in_running_ = false; + } + + void Start() { + client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); + in_running_ = true; + } + + void AsyncCompleteRpc() { + void *got_tag; + bool ok = false; + + while (cq_.Next(&got_tag, &ok)) { + AsyncClientCall *call = static_cast(got_tag); + if (call->status.ok()) { + call->callback(SUCCESS); + } else { + MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); + call->callback(Status(FAILED, call->status.error_message())); + } + delete call; + } + } + + void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) { + AsyncClientCall *call = new AsyncClientCall; + call->reply = reply; + call->callback = std::move(callback); + call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); + call->response_reader->StartCall(); + call->response_reader->Finish(call->reply, &call->status, call); + MSI_LOG(INFO) << "Finish send Predict"; + } + + private: + struct AsyncClientCall { + grpc::ClientContext context; + grpc::Status status; + Reply *reply; + 
DispatchCallback callback; + std::shared_ptr> response_reader; + }; + + grpc::CompletionQueue cq_; + std::thread client_thread_; + bool in_running_ = false; +}; + +using MSPredictClient = MSServiceClient; +using MSDistributedClient = + MSServiceClient; +extern std::unique_ptr client_; +extern std::unique_ptr distributed_client_; +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H diff --git a/mindspore_serving/ccsrc/common/proto_tensor.cc b/mindspore_serving/ccsrc/common/proto_tensor.cc index 2b8fd9e..c4f8d16 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.cc +++ b/mindspore_serving/ccsrc/common/proto_tensor.cc @@ -341,6 +341,56 @@ Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::Predict return SUCCESS; } +void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) { + worker_specs->rank_id = specs.rank_id(); + worker_specs->batch_size = specs.batch_size(); + for (auto &in : specs.inputs()) { + TensorInfo info; + info.data_type = ProtoTensor::TransDataType2Inference(in.dtype()); + info.size = in.size(); + for (auto &dim : in.shape().dims()) { + info.shape.push_back(dim); + } + worker_specs->input_infos.push_back(info); + } + for (auto &out : specs.outputs()) { + TensorInfo info; + info.data_type = ProtoTensor::TransDataType2Inference(out.dtype()); + for (auto &dim : out.shape().dims()) { + info.shape.push_back(dim); + } + worker_specs->output_infos.push_back(info); + } +} + +void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vector &worker_specs, + proto::AgentRegisterRequest *request) { + for (size_t i = 0; i < worker_specs.size(); i++) { + auto &spec = worker_specs[i]; + auto worker_spec = request->add_agent_spec(); + worker_spec->set_rank_id(spec.rank_id); + worker_spec->set_batch_size(spec.batch_size); + for (auto &method : spec.input_infos) { + auto proto_method = worker_spec->add_inputs(); + 
proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + proto_method->set_size(method.size); + auto proto_shape = proto_method->mutable_shape(); + for (auto &dim : method.shape) { + proto_shape->add_dims(dim); + } + } + for (auto &method : spec.output_infos) { + auto proto_method = worker_spec->add_outputs(); + proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + proto_method->set_size(method.size); + auto proto_shape = proto_method->mutable_shape(); + for (auto &dim : method.shape) { + proto_shape->add_dims(dim); + } + } + } +} + Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { Status status; ProtoTensor tensor_input(const_cast(&tensor)); diff --git a/mindspore_serving/ccsrc/common/proto_tensor.h b/mindspore_serving/ccsrc/common/proto_tensor.h index 554da02..d80b982 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.h +++ b/mindspore_serving/ccsrc/common/proto_tensor.h @@ -24,6 +24,7 @@ #include "common/serving_common.h" #include "proto/ms_service.pb.h" #include "proto/ms_master.pb.h" +#include "proto/ms_distributed.pb.h" #include "common/instance.h" #include "common/servable.h" @@ -68,6 +69,9 @@ class MS_API GrpcTensorHelper { std::vector *results); static Status CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector &inputs, proto::PredictReply *reply); + static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs); + static void CopyFromWorkerAgentSpec(const std::vector &worker_specs, + proto::AgentRegisterRequest *request); private: static Status CreateInstanceFromRequestInstances(const proto::PredictRequest &request, diff --git a/mindspore_serving/ccsrc/common/servable.cc b/mindspore_serving/ccsrc/common/servable.cc index d36bacf..26f5957 100644 --- a/mindspore_serving/ccsrc/common/servable.cc +++ b/mindspore_serving/ccsrc/common/servable.cc @@ -25,11 +25,23 @@ namespace mindspore::serving { std::string 
ServableMeta::Repr() const { std::ostringstream stream; - stream << "path(" << servable_name << ") file(" << servable_file + ")"; + switch (servable_type) { + case kServableTypeUnknown: + stream << "undeclared servable, servable name: '" << common_meta.servable_name << "'"; + break; + case kServableTypeLocal: + stream << "local servable, servable name: '" << common_meta.servable_name << "', file: '" + << local_meta.servable_file + "'"; + break; + case kServableTypeDistributed: + stream << "distributed servable, servable name: '" << common_meta.servable_name + << "', rank size: " << distributed_meta.rank_size << ", stage size " << distributed_meta.stage_size; + break; + } return stream.str(); } -void ServableMeta::SetModelFormat(const std::string &format) { +void LocalServableMeta::SetModelFormat(const std::string &format) { if (format == "om") { model_format = kOM; } else if (format == "mindir") { @@ -63,142 +75,181 @@ std::string RequestSpec::Repr() const { return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version; } -Status ServableSignature::Check() const { - std::set method_set; +Status ServableSignature::CheckPreprocessInput(const MethodSignature &method, size_t *preprocess_outputs_count) const { std::string model_str = servable_meta.Repr(); + const auto &preprocess_name = method.preprocess_name; + if (!preprocess_name.empty()) { + auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); + if (preprocess == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name + << " preprocess " << preprocess_name << " not defined"; + } + *preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); - for (auto &method : methods) { - if (method_set.count(method.method_name) > 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " " << method.method_name << " has been defined repeatly"; + for (size_t i = 0; i < 
method.preprocess_inputs.size(); i++) { + auto &input = method.preprocess_inputs[i]; + if (input.first != kPredictPhaseTag_Input) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i + << "th input cannot not come from '" << input.first << "'"; + } + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i + << "th input uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); + } } - method_set.emplace(method.method_name); + } + return SUCCESS; +} - size_t preprocess_outputs_count = 0; - size_t postprocess_outputs_count = 0; +Status ServableSignature::CheckPredictInput(const MethodSignature &method, size_t preprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); - const auto &preprocess_name = method.preprocess_name; - if (!preprocess_name.empty()) { - auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); - if (preprocess == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name - << " preprocess " << preprocess_name << " not defined"; + for (size_t i = 0; i < method.servable_inputs.size(); i++) { + auto &input = method.servable_inputs[i]; + if (input.first == kPredictPhaseTag_Input) { + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "th input uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); } - preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); - - for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { - auto &input = method.preprocess_inputs[i]; - if (input.first != 
kPredictPhaseTag_Input) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i - << "th input cannot not come from '" << input.first << "'"; - } - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i - << "th input uses method " << input.second << "th input, that is greater than the method inputs size " - << method.inputs.size(); - } + } else if (input.first == kPredictPhaseTag_Preproces) { + if (input.second >= preprocess_outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "th input uses preprocess " << input.second + << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } + } else { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i + << "th input cannot not come from '" << input.first << "'"; + } + } + return SUCCESS; +} + +Status ServableSignature::CheckPostprocessInput(const MethodSignature &method, size_t preprocess_outputs_count, + size_t *postprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); + const auto &common_meta = servable_meta.common_meta; + + const auto &postprocess_name = method.postprocess_name; + if (!method.postprocess_name.empty()) { + auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); + if (postprocess == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name + << " postprocess " << postprocess_name << " not defined"; } + *postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); - for (size_t i = 0; i < method.servable_inputs.size(); i++) { - auto &input = method.servable_inputs[i]; 
+ for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { + auto &input = method.postprocess_inputs[i]; if (input.first == kPredictPhaseTag_Input) { if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } else if (input.first == kPredictPhaseTag_Preproces) { if (input.second >= preprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses preprocess " << input.second << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } + } else if (input.first == kPredictPhaseTag_Predict) { + if (input.second >= common_meta.outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i + << "th input uses servable " << input.second + << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; + } } else { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i + << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i << "th input cannot not come from '" << input.first << "'"; } } + } + return SUCCESS; +} - const auto &postprocess_name = method.postprocess_name; - if (!method.postprocess_name.empty()) { - auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); - if (postprocess == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) << 
"Model " << model_str << " method " << method.method_name - << " postprocess " << postprocess_name << " not defined"; - } - postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); +Status ServableSignature::CheckReturn(const MethodSignature &method, size_t preprocess_outputs_count, + size_t postprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); + const auto &common_meta = servable_meta.common_meta; - for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { - auto &input = method.postprocess_inputs[i]; - if (input.first == kPredictPhaseTag_Input) { - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses method " << input.second - << "th input, that is greater than the method inputs size " << method.inputs.size(); - } - } else if (input.first == kPredictPhaseTag_Preproces) { - if (input.second >= preprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses preprocess " << input.second - << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; - } - } else if (input.first == kPredictPhaseTag_Predict) { - if (input.second >= servable_meta.outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses servable " << input.second - << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; - } - } else { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i - << "th input cannot not come from '" << input.first << "'"; - } + for (size_t i = 0; i < method.returns.size(); i++) { + auto &input = 
method.returns[i]; + if (input.first == kPredictPhaseTag_Input) { + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); } - } - for (size_t i = 0; i < method.returns.size(); i++) { - auto &input = method.returns[i]; - if (input.first == kPredictPhaseTag_Input) { - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses method " << input.second << "th input, that is greater than the method inputs size " - << method.inputs.size(); - } - } else if (input.first == kPredictPhaseTag_Preproces) { - if (input.second >= preprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses preprocess " << input.second - << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; - } - } else if (input.first == kPredictPhaseTag_Predict) { - if (input.second >= servable_meta.outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses servable " << input.second - << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; - } - } else if (input.first == kPredictPhaseTag_Postprocess) { - if (input.second >= postprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses postprocess " << input.second - << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; - } - } else { 
+ } else if (input.first == kPredictPhaseTag_Preproces) { + if (input.second >= preprocess_outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses preprocess " << input.second + << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; + } + } else if (input.first == kPredictPhaseTag_Predict) { + if (input.second >= common_meta.outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses servable " << input.second + << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; + } + } else if (input.first == kPredictPhaseTag_Postprocess) { + if (input.second >= postprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of method " << i - << "th output cannot not come from '" << input.first << "'"; + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses postprocess " << input.second + << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; } + } else { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of method " << i + << "th output cannot not come from '" << input.first << "'"; + } + } + return SUCCESS; +} + +Status ServableSignature::Check() const { + std::set method_set; + Status status; + for (auto &method : methods) { + if (method_set.count(method.method_name) > 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << servable_meta.Repr() << " " << method.method_name << " has been defined repeatedly"; + } + method_set.emplace(method.method_name); + + size_t preprocess_outputs_count = 0; + size_t postprocess_outputs_count = 0; + status = 
CheckPreprocessInput(method, &preprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckPredictInput(method, preprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckPostprocessInput(method, preprocess_outputs_count, &postprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckReturn(method, preprocess_outputs_count, postprocess_outputs_count); + if (status != SUCCESS) { + return status; } } return SUCCESS; @@ -216,7 +267,7 @@ bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodS } void ServableStorage::Register(const ServableSignature &def) { - auto model_name = def.servable_meta.servable_name; + auto model_name = def.servable_meta.common_meta.servable_name; if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) { MSI_LOG_WARNING << "Servable " << model_name << " has already been defined"; } @@ -258,16 +309,60 @@ Status ServableStorage::RegisterMethod(const MethodSignature &method) { return SUCCESS; } -void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) { - MSI_LOG_INFO << "Declare servable " << servable.servable_name; - auto it = servable_signatures_map_.find(servable.servable_name); +Status ServableStorage::DeclareServable(ServableMeta servable) { + auto &common_meta = servable.common_meta; + MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; + servable.servable_type = kServableTypeLocal; + if (servable.local_meta.servable_file.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty"; + } + if (servable.local_meta.model_format == ModelType::kUnknownType) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare servable " << common_meta.servable_name << " failed, model_format is not inited"; + } + auto it = servable_signatures_map_.find(common_meta.servable_name); if (it == 
servable_signatures_map_.end()) { ServableSignature signature; signature.servable_meta = servable; - servable_signatures_map_[servable.servable_name] = signature; - return; + servable_signatures_map_[common_meta.servable_name] = signature; + return SUCCESS; } - it->second.servable_meta = servable; + auto &org_servable_meta = it->second.servable_meta; + if (org_servable_meta.servable_type != kServableTypeUnknown) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); + } + org_servable_meta = servable; + return SUCCESS; +} + +Status ServableStorage::DeclareDistributedServable(ServableMeta servable) { + auto &common_meta = servable.common_meta; + MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; + servable.servable_type = kServableTypeDistributed; + if (servable.distributed_meta.rank_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare distributed servable " << common_meta.servable_name << " failed, rank_size cannot be 0"; + } + if (servable.distributed_meta.stage_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare distributed servable " << common_meta.servable_name << " failed, stage_size cannot be 0"; + } + auto it = servable_signatures_map_.find(common_meta.servable_name); + if (it == servable_signatures_map_.end()) { + ServableSignature signature; + signature.servable_meta = servable; + servable_signatures_map_[common_meta.servable_name] = signature; + return SUCCESS; + } + auto &org_servable_meta = it->second.servable_meta; + if (org_servable_meta.servable_type != kServableTypeUnknown) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); + } + org_servable_meta = servable; + return SUCCESS; } Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, @@ -277,18 +372,19 @@ Status 
ServableStorage::RegisterInputOutputInfo(const std::string &servable_name return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name; } auto &servable_meta = it->second.servable_meta; - if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) { + auto &common_meta = servable_meta.common_meta; + if (common_meta.inputs_count != 0 && common_meta.inputs_count != inputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count " - << servable_meta.inputs_count << ",servable name " << servable_name; + << common_meta.inputs_count << ",servable name " << servable_name; } - if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) { + if (common_meta.outputs_count != 0 && common_meta.outputs_count != outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count " - << servable_meta.outputs_count << ",servable name " << servable_name; + << common_meta.outputs_count << ",servable name " << servable_name; } - servable_meta.inputs_count = inputs_count; - servable_meta.outputs_count = outputs_count; + common_meta.inputs_count = inputs_count; + common_meta.outputs_count = outputs_count; return SUCCESS; } @@ -298,8 +394,8 @@ std::vector ServableStorage::GetInputOutputInfo(const std::string &serva if (it == servable_signatures_map_.end()) { return result; } - result.push_back(it->second.servable_meta.inputs_count); - result.push_back(it->second.servable_meta.outputs_count); + result.push_back(it->second.servable_meta.common_meta.inputs_count); + result.push_back(it->second.servable_meta.common_meta.outputs_count); return result; } diff --git a/mindspore_serving/ccsrc/common/servable.h b/mindspore_serving/ccsrc/common/servable.h index 5402bc0..b64d99b 100644 --- a/mindspore_serving/ccsrc/common/servable.h +++ 
b/mindspore_serving/ccsrc/common/servable.h @@ -81,19 +81,39 @@ struct RequestSpec { std::string Repr() const; }; -struct MS_API ServableMeta { +enum ServableType { + kServableTypeUnknown = 0, + kServableTypeLocal = 1, + kServableTypeDistributed = 2, +}; + +struct CommonServableMeta { std::string servable_name; - std::string servable_file; // file name - ModelType model_format; // OM, MindIR bool with_batch_dim = true; // whether there is batch dim in model's inputs/outputs + std::vector without_batch_dim_inputs; size_t inputs_count = 0; size_t outputs_count = 0; +}; - std::map load_options; // Acl options - std::vector without_batch_dim_inputs; +struct MS_API LocalServableMeta { + std::string servable_file; // file name + ModelType model_format = ModelType::kUnknownType; // OM, MindIR + std::map load_options; // Acl options + void SetModelFormat(const std::string &format); +}; + +struct DistributedServableMeta { + size_t rank_size = 0; + size_t stage_size = 0; +}; + +struct MS_API ServableMeta { + ServableType servable_type = kServableTypeUnknown; + CommonServableMeta common_meta; + LocalServableMeta local_meta; + DistributedServableMeta distributed_meta; std::string Repr() const; - void SetModelFormat(const std::string &format); }; struct ServableSignature { @@ -102,6 +122,12 @@ struct ServableSignature { Status Check() const; bool GetMethodDeclare(const std::string &method_name, MethodSignature *method); + + private: + Status CheckPreprocessInput(const MethodSignature &method, size_t *pre) const; + Status CheckPredictInput(const MethodSignature &method, size_t pre) const; + Status CheckPostprocessInput(const MethodSignature &method, size_t pre, size_t *post) const; + Status CheckReturn(const MethodSignature &method, size_t pre, size_t post) const; }; class MS_API ServableStorage { @@ -111,7 +137,8 @@ class MS_API ServableStorage { bool GetServableDef(const std::string &model_name, ServableSignature *def) const; - void DeclareServable(const ServableMeta 
&servable); + Status DeclareServable(ServableMeta servable); + Status DeclareDistributedServable(ServableMeta servable); Status RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); std::vector GetInputOutputInfo(const std::string &servable_name) const; @@ -144,6 +171,14 @@ static inline LogStream &operator<<(LogStream &stream, PredictPhaseTag data_type return stream; } +struct WorkerAgentSpec { + std::string agent_address; + uint32_t rank_id = 0; + std::vector input_infos; + std::vector output_infos; + uint32_t batch_size = 0; +}; + } // namespace mindspore::serving #endif // MINDSPORE_SERVING_SERVABLE_H diff --git a/mindspore_serving/ccsrc/master/dispacther.h b/mindspore_serving/ccsrc/master/dispacther.h index 95f632b..4e5e406 100644 --- a/mindspore_serving/ccsrc/master/dispacther.h +++ b/mindspore_serving/ccsrc/master/dispacther.h @@ -27,7 +27,7 @@ #include "common/instance.h" #include "common/servable.h" #include "master/notify_worker/base_notify.h" -#include "master/grpc/grpc_client.h" +#include "common/grpc_client.h" namespace mindspore::serving { diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_client.cc b/mindspore_serving/ccsrc/master/grpc/grpc_client.cc deleted file mode 100644 index 85ad129..0000000 --- a/mindspore_serving/ccsrc/master/grpc/grpc_client.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "master/grpc/grpc_client.h" -#include -#include -#include "master/grpc/grpc_server.h" - -namespace mindspore { -namespace serving { -std::unique_ptr client_; - -MSServiceClient::~MSServiceClient() { - if (in_running_) { - cq_.Shutdown(); - if (client_thread_.joinable()) { - try { - client_thread_.join(); - } catch (const std::system_error &) { - } catch (...) { - } - } - } - in_running_ = false; -} - -void MSServiceClient::PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - std::shared_ptr stub, DispatchCallback callback) { - AsyncClientCall *call = new AsyncClientCall; - call->reply = reply; - call->callback = std::move(callback); - call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); - call->response_reader->StartCall(); - call->response_reader->Finish(call->reply, &call->status, call); - MSI_LOG(INFO) << "Finish send Predict"; -} - -void MSServiceClient::AsyncCompleteRpc() { - void *got_tag; - bool ok = false; - - while (cq_.Next(&got_tag, &ok)) { - AsyncClientCall *call = static_cast(got_tag); - if (call->status.ok()) { - call->callback(SUCCESS); - } else { - MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); - call->callback(Status(FAILED, call->status.error_message())); - } - delete call; - } -} - -void MSServiceClient::Start() { - client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); - in_running_ = true; -} - -} // namespace serving -} // namespace mindspore diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_client.h b/mindspore_serving/ccsrc/master/grpc/grpc_client.h deleted file mode 100644 index ca39a48..0000000 --- a/mindspore_serving/ccsrc/master/grpc/grpc_client.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H -#define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H - -#include -#include -#include -#include -#include -#include -#include "common/serving_common.h" -#include "master/notify_worker/base_notify.h" -#include "proto/ms_service.pb.h" -#include "proto/ms_service.grpc.pb.h" -#include "proto/ms_master.pb.h" -#include "proto/ms_master.grpc.pb.h" -#include "proto/ms_worker.grpc.pb.h" - -namespace mindspore { -namespace serving { -class MSServiceClient; -extern std::unique_ptr client_; - -using PredictOnFinish = std::function; - -class MSServiceClient { - public: - MSServiceClient() = default; - ~MSServiceClient(); - void AsyncCompleteRpc(); - void Start(); - - void PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - std::shared_ptr stub, DispatchCallback callback); - - private: - struct AsyncClientCall { - grpc::ClientContext context; - grpc::Status status; - proto::PredictReply *reply; - DispatchCallback callback; - std::shared_ptr> response_reader; - }; - - grpc::CompletionQueue cq_; - std::thread client_thread_; - bool in_running_ = false; -}; - -} // namespace serving -} // namespace mindspore - -#endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H diff --git a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h index 5d8cb11..5ccb0c3 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h +++ b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h @@ -22,12 +22,11 @@ #include 
"common/serving_common.h" #include "common/servable.h" #include "proto/ms_service.pb.h" +#include "common/grpc_client.h" namespace mindspore { namespace serving { -using DispatchCallback = std::function; - class MS_API BaseNotifyWorker { public: BaseNotifyWorker() = default; diff --git a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc index 4420d44..b60a86d 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc +++ b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc @@ -20,7 +20,6 @@ #include #include "common/exit_handle.h" #include "common/grpc_server.h" -#include "master/grpc/grpc_client.h" namespace mindspore { namespace serving { @@ -56,10 +55,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro << worker_address_; } if (!client_) { - client_ = std::make_unique(); + client_ = std::make_unique(); client_->Start(); } - client_->PredictAsync(request, reply, stub_, callback); + client_->PredictAsync(request, reply, stub_.get(), callback); return SUCCESS; } diff --git a/mindspore_serving/ccsrc/master/server.cc b/mindspore_serving/ccsrc/master/server.cc index 5117bad..980daac 100644 --- a/mindspore_serving/ccsrc/master/server.cc +++ b/mindspore_serving/ccsrc/master/server.cc @@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma if (grpc_async_server_ != nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; } - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (max_msg_mb_size > gRpcMaxMBMsgSize) { MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size << "MB to 512MB"; @@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma } Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { - 
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return grpc_manager_server_.Start(std::make_shared(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, "Master"); } Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, int time_out_second) { - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second); } diff --git a/mindspore_serving/ccsrc/python/agent/agent_py.cc b/mindspore_serving/ccsrc/python/agent/agent_py.cc new file mode 100644 index 0000000..c2c1465 --- /dev/null +++ b/mindspore_serving/ccsrc/python/agent/agent_py.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/worker_agent.h" + +namespace mindspore::serving { + +DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { + auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + DistributedServableConfig config; + status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + return config; +} + +void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port); +} + +void PyAgent::StartAgent(const AgentStartUpConfig &start_config) { + auto status = WorkerAgent::Instance().StartAgent(start_config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyAgent::WaitAndClear() { + { + py::gil_scoped_release release; + ExitSignalHandle::Instance().AgentWait(); + } + WorkerAgent::Instance().Clear(); + MSI_LOG_INFO << "Python agent end wait and clear"; +} + +void PyAgent::StopAndClear() { + ExitSignalHandle::Instance().Stop(); + WorkerAgent::Instance().Clear(); +} + +} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/python/agent/agent_py.h b/mindspore_serving/ccsrc/python/agent/agent_py.h new file mode 100644 index 0000000..708b673 --- /dev/null +++ b/mindspore_serving/ccsrc/python/agent/agent_py.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVER_AGENT_PY_H +#define MINDSPORE_SERVER_AGENT_PY_H + +#include +#include +#include +#include +#include +#include "common/serving_common.h" +#include "worker/distributed_worker/common.h" + +namespace py = pybind11; + +namespace mindspore { +namespace serving { + +class MS_API PyAgent { + public: + static void StartAgent(const AgentStartUpConfig &start_config); + + static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); + static void WaitAndClear(); + static void StopAndClear(); + // from start up, not agent + static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port); +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVER_AGENT_PY_H diff --git a/mindspore_serving/ccsrc/python/serving_py.cc b/mindspore_serving/ccsrc/python/serving_py.cc index 7de691c..adf29d3 100644 --- a/mindspore_serving/ccsrc/python/serving_py.cc +++ b/mindspore_serving/ccsrc/python/serving_py.cc @@ -23,10 +23,14 @@ #include "common/servable.h" #include "worker/context.h" #include "python/master/master_py.h" +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/worker_agent.h" namespace mindspore::serving { -PYBIND11_MODULE(_mindspore_serving, m) { +void PyRegServable(pybind11::module *m_ptr) { + auto &m = *m_ptr; // avoid as numpy object memory copy in PyTensor::AsPythonData py::class_(m, "Tensor_"); @@ -68,16 +72,30 @@ PYBIND11_MODULE(_mindspore_serving, m) { .def_readwrite("version_number", 
&RequestSpec::version_number) .def_readwrite("method_name", &RequestSpec::method_name); + py::class_(m, "CommonServableMeta_") + .def(py::init<>()) + .def_readwrite("servable_name", &CommonServableMeta::servable_name) + .def_readwrite("inputs_count", &CommonServableMeta::inputs_count) + .def_readwrite("outputs_count", &CommonServableMeta::outputs_count) + .def_readwrite("with_batch_dim", &CommonServableMeta::with_batch_dim) + .def_readwrite("without_batch_dim_inputs", &CommonServableMeta::without_batch_dim_inputs); + + py::class_(m, "LocalServableMeta_") + .def(py::init<>()) + .def_readwrite("servable_file", &LocalServableMeta::servable_file) + .def_readwrite("options", &LocalServableMeta::load_options) + .def("set_model_format", &LocalServableMeta::SetModelFormat); + + py::class_(m, "DistributedServableMeta_") + .def(py::init<>()) + .def_readwrite("rank_size", &DistributedServableMeta::rank_size) + .def_readwrite("stage_size", &DistributedServableMeta::stage_size); + py::class_(m, "ServableMeta_") .def(py::init<>()) - .def_readwrite("servable_name", &ServableMeta::servable_name) - .def_readwrite("inputs_count", &ServableMeta::inputs_count) - .def_readwrite("outputs_count", &ServableMeta::outputs_count) - .def_readwrite("servable_file", &ServableMeta::servable_file) - .def_readwrite("with_batch_dim", &ServableMeta::with_batch_dim) - .def_readwrite("options", &ServableMeta::load_options) - .def_readwrite("without_batch_dim_inputs", &ServableMeta::without_batch_dim_inputs) - .def("set_model_format", &ServableMeta::SetModelFormat); + .def_readwrite("common_meta", &ServableMeta::common_meta) + .def_readwrite("local_meta", &ServableMeta::local_meta) + .def_readwrite("distributed_meta", &ServableMeta::distributed_meta); py::class_(m, "ServableSignature_") .def(py::init<>()) @@ -87,8 +105,34 @@ PYBIND11_MODULE(_mindspore_serving, m) { py::class_(m, "ServableStorage_") .def_static("register_servable_input_output_info", &PyServableStorage::RegisterInputOutputInfo) 
.def_static("register_method", &PyServableStorage::RegisterMethod) - .def_static("declare_servable", &PyServableStorage::DeclareServable); + .def_static("declare_servable", &PyServableStorage::DeclareServable) + .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); + + py::class_(m, "OneRankConfig_") + .def(py::init<>()) + .def_readwrite("device_id", &OneRankConfig::device_id) + .def_readwrite("ip", &OneRankConfig::ip); + + py::class_(m, "DistributedServableConfig_") + .def(py::init<>()) + .def_readwrite("common_meta", &DistributedServableConfig::common_meta) + .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) + .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) + .def_readwrite("rank_list", &DistributedServableConfig::rank_list); +} + +void PyRegMaster(pybind11::module *m_ptr) { + auto &m = *m_ptr; + py::class_(m, "Master_") + .def_static("start_grpc_server", &PyMaster::StartGrpcServer) + .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) + .def_static("start_restful_server", &PyMaster::StartRestfulServer) + .def_static("wait_and_clear", &PyMaster::WaitAndClear) + .def_static("stop_and_clear", &PyMaster::StopAndClear); +} +void PyRegWorker(pybind11::module *m_ptr) { + auto &m = *m_ptr; py::class_(m, "TaskContext_").def(py::init<>()); py::class_(m, "TaskItem_") @@ -108,6 +152,8 @@ PYBIND11_MODULE(_mindspore_serving, m) { py::class_(m, "Worker_") .def_static("start_servable", &PyWorker::StartServable) .def_static("start_servable_in_master", &PyWorker::StartServableInMaster) + .def_static("start_distributed_servable", &PyWorker::StartDistributedServable) + .def_static("start_distributed_servable_in_master", &PyWorker::StartDistributedServableInMaster) .def_static("get_batch_size", &PyWorker::GetBatchSize) .def_static("wait_and_clear", &PyWorker::WaitAndClear) .def_static("stop_and_clear", PyWorker::StopAndClear) @@ -130,17 +176,52 @@ 
PYBIND11_MODULE(_mindspore_serving, m) { } }) .def("set_device_id", &ServableContext::SetDeviceId); +} - py::class_>(m, "Master_") - .def_static("start_grpc_server", &PyMaster::StartGrpcServer) - .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) - .def_static("start_restful_server", &PyMaster::StartRestfulServer) - .def_static("wait_and_clear", &PyMaster::WaitAndClear) - .def_static("stop_and_clear", &PyMaster::StopAndClear); +void PyRegWorkerAgent(pybind11::module *m_ptr) { + auto &m = *m_ptr; + py::class_(m, "WorkerAgent_") + .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) + .def_static("wait_and_clear", &PyAgent::WaitAndClear) + .def_static("stop_and_clear", &PyAgent::StopAndClear) + .def_static("notify_failed", &PyAgent::NotifyFailed) + .def_static("start_agent", &PyAgent::StartAgent); + + py::class_(m, "AgentStartUpConfig_") + .def(py::init<>()) + .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) + .def_readwrite("device_id", &AgentStartUpConfig::device_id) + .def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name) + .def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name) + .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) + .def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip) + .def_readwrite("agent_port", &AgentStartUpConfig::agent_port) + .def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip) + .def_readwrite("worker_port", &AgentStartUpConfig::worker_port) + .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); +} + +class PyExitSignalHandle { + public: + static void Start() { ExitSignalHandle::Instance().Start(); } + static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } +}; + +// cppcheck-suppress syntaxError +PYBIND11_MODULE(_mindspore_serving, m) { + PyRegServable(&m); + PyRegMaster(&m); + PyRegWorker(&m); + PyRegWorkerAgent(&m); + + py::class_(m, "ExitSignalHandle_") + 
.def_static("start", &PyExitSignalHandle::Start) + .def_static("has_stopped", &PyExitSignalHandle::HasStopped); (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { Server::Instance().Clear(); Worker::GetInstance().Clear(); + WorkerAgent::Instance().Clear(); }}); } diff --git a/mindspore_serving/ccsrc/python/worker/servable_py.cc b/mindspore_serving/ccsrc/python/worker/servable_py.cc index 8fa565e..b320722 100644 --- a/mindspore_serving/ccsrc/python/worker/servable_py.cc +++ b/mindspore_serving/ccsrc/python/worker/servable_py.cc @@ -25,7 +25,16 @@ void PyServableStorage::RegisterMethod(const MethodSignature &method) { } } void PyServableStorage::DeclareServable(const ServableMeta &servable) { - ServableStorage::Instance().DeclareServable(servable); + auto status = ServableStorage::Instance().DeclareServable(servable); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} +void PyServableStorage::DeclareDistributedServable(const ServableMeta &servable) { + auto status = ServableStorage::Instance().DeclareDistributedServable(servable); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } } void PyServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count) { diff --git a/mindspore_serving/ccsrc/python/worker/servable_py.h b/mindspore_serving/ccsrc/python/worker/servable_py.h index af9b26f..759289e 100644 --- a/mindspore_serving/ccsrc/python/worker/servable_py.h +++ b/mindspore_serving/ccsrc/python/worker/servable_py.h @@ -27,6 +27,7 @@ class MS_API PyServableStorage { static void RegisterMethod(const MethodSignature &method); static void DeclareServable(const ServableMeta &servable); + static void DeclareDistributedServable(const ServableMeta &servable); static void RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); static void Clear(); 
diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index 2980aae..c1b03a5 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -21,21 +21,33 @@ #include "common/exit_handle.h" #include "worker/notfiy_master/grpc_notify.h" #include "worker/notfiy_master/local_notify.h" +#include "worker/local_servable/local_sevable.h" +#include "worker/distributed_worker/distributed_servable.h" +#include "worker/grpc/worker_server.h" +#include "worker/distributed_worker/distributed_process/distributed_server.h" namespace mindspore::serving { void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, - const std::string &master_ip, uint32_t master_port, const std::string &host_ip, - uint32_t host_port) { - auto notify_master = std::make_shared(master_ip, master_port, host_ip, host_port); - auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); + const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, + uint32_t worker_port) { + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); + auto servable = std::make_shared(); + auto status = servable->StartServable(model_directory, model_name, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); + status = Worker::GetInstance().StartServable(servable, notify_master); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + // start grpc server + auto grpc_sever = std::make_shared(); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = 
Worker::GetInstance().StartVersionController(); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -45,7 +57,69 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name, uint32_t version_number) { auto notify_master = std::make_shared(); - auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); + auto servable = std::make_shared(); + auto status = servable->StartServable(model_directory, model_name, version_number); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartVersionController(); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port, + const std::string &master_ip, uint32_t master_port, + uint32_t wait_agents_time_in_seconds) { + Status status; + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << 
status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartVersionController(); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds) { + Status status; + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + auto notify_master = std::make_shared(); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.h b/mindspore_serving/ccsrc/python/worker/worker_py.h index cf595e3..e6b2c6d 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.h +++ b/mindspore_serving/ccsrc/python/worker/worker_py.h @@ -34,6 +34,16 @@ class MS_API PyWorker { static void StartServableInMaster(const std::string &model_directory, const std::string &model_name, uint32_t version_number); + static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const 
std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, + uint32_t master_port, uint32_t wait_agents_time_in_seconds); + + static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds); + static int GetBatchSize(); static void WaitAndClear(); static void StopAndClear(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc new file mode 100644 index 0000000..e55509b --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "worker/distributed_worker/agent_executor.h" + +namespace mindspore { +namespace serving { + +Status WorkerAgentExecutor::LoadModelFromFile(const AgentStartUpConfig &config) { return Status(); } +Status WorkerAgentExecutor::UnloadModel() { return Status(); } +Status WorkerAgentExecutor::ExecuteModel(const std::vector &request, std::vector *reply) { + return Status(); +} +std::vector WorkerAgentExecutor::GetInputInfos() const { + return std::vector(); +} +std::vector WorkerAgentExecutor::GetOutputInfos() const { + return std::vector(); +} +ssize_t WorkerAgentExecutor::GetBatchSize() const { return 0; } +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h new file mode 100644 index 0000000..7a00fb0 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H +#define MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H + +#include +#include "common/serving_common.h" +#include "worker/inference/inference.h" +#include "worker/distributed_worker/common.h" + +namespace mindspore { +namespace serving { +class MS_API WorkerAgentExecutor { + public: + // from python + Status LoadModelFromFile(const AgentStartUpConfig &config); + // ctrl+c, worker exit + Status UnloadModel(); + + // from worker + Status ExecuteModel(const std::vector &request, std::vector *reply); + + // for register + std::vector GetInputInfos() const; + + std::vector GetOutputInfos() const; + + ssize_t GetBatchSize() const; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc new file mode 100644 index 0000000..6e1750a --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "worker/distributed_worker/agent_process/agent_process.h" +#include "worker/distributed_worker/worker_agent.h" + +namespace mindspore { +namespace serving { +grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) { + MSI_LOG(INFO) << "Distributed Worker Exit"; + WorkerAgent::Instance().StopAgent(false); + return grpc::Status::OK; +} + +grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) { + MSI_LOG(INFO) << "Begin call service Eval"; + WorkerAgent::Instance().Run(*request, reply); + return grpc::Status::OK; +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h new file mode 100644 index 0000000..d0ea12c --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H +#define MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H + +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" + +namespace mindspore { +namespace serving { + +// Service Implement +class MSAgentImpl final : public proto::MSAgent::Service { + public: + grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) override; + grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) override; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc new file mode 100644 index 0000000..8ec9a39 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" + +namespace mindspore { +namespace serving { + +WorkerAgentStartUp &WorkerAgentStartUp::Instance() { + static WorkerAgentStartUp instance; + return instance; +} + +Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { + return Status(); +} + +Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { + MSI_EXCEPTION_IF_NULL(config); + if (config_.rank_list.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; + } + *config = config_; + return SUCCESS; +} + +Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port); +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h new file mode 100644 index 0000000..ad28e5c --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H
+#define MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H
+#include
+#include
+#include "common/serving_common.h"
+#include "worker/distributed_worker/common.h"
+#include "worker/inference/inference.h"
+
+namespace mindspore {
+namespace serving {
+
+// Singleton used while starting worker agents; it caches the distributed servable
+// configuration (config_) and can report a startup failure back to the worker.
+class MS_API WorkerAgentStartUp {
+ public:
+  // Returns the process-wide instance.
+  static WorkerAgentStartUp &Instance();
+  // Called from Python (worker_agent.py, start_worker_agent).
+  // Step 1: get the agents' config from the worker at worker_ip:worker_port.
+  // NOTE(review): the implementation in agent_startup.cc is currently a stub that returns
+  // a default Status without filling config_.
+  Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
+  // Step 2, invoked from Python: copies the cached config into `*config`.
+  // Fails while the cached rank list is empty.
+  Status GetDistributedServableConfig(DistributedServableConfig *config);
+
+  // Notifies the worker at worker_ip:worker_port that the agents failed
+  // (forwards to GrpcNotifyDistributeWorker::NotifyFailed).
+  Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);
+
+ private:
+  DistributedServableConfig config_;  // returned by GetDistributedServableConfig
+  std::string worker_address_;        // not referenced by the methods in agent_startup.cc
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/common.h b/mindspore_serving/ccsrc/worker/distributed_worker/common.h
new file mode 100644
index 0000000..c145bcd
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/common.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H +#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H + +#include +#include +#include +#include "common/serving_common.h" +#include "worker/inference/inference.h" +#include "common/servable.h" + +namespace mindspore { +namespace serving { + +struct OneRankConfig { + std::string ip; + uint32_t device_id = 0; +}; + +struct DistributedServableConfig { + std::string rank_table_content; + std::vector rank_list; + + CommonServableMeta common_meta; + DistributedServableMeta distributed_meta; +}; + +struct AgentStartUpConfig { + uint32_t rank_id; + uint32_t device_id; + std::string model_file_name; + std::string group_file_name; + std::string rank_table_json_file_name; + + std::string agent_ip; + uint32_t agent_port; + std::string worker_ip; + uint32_t worker_port; + + CommonServableMeta common_meta; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc new file mode 100644 index 0000000..48d1042 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/worker.h" +#include "common/proto_tensor.h" + +namespace mindspore { +namespace serving { + +grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, + proto::AgentRegisterReply *reply) { + MSI_EXCEPTION_IF_NULL(request); + MSI_EXCEPTION_IF_NULL(reply); + for (auto &spec : request->agent_spec()) { + WorkerAgentSpec agent_spec; + agent_spec.agent_address = request->address(); + GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); + Status status(FAILED); + status = servable_->RegisterAgent(agent_spec); + if (status != SUCCESS) { + MSI_LOG(ERROR) << "Agent Register FAILED"; + } + } + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, + proto::AgentExitReply *reply) { + MSI_EXCEPTION_IF_NULL(request); + MSI_EXCEPTION_IF_NULL(reply); + for (auto &spec : request->agent_spec()) { + WorkerAgentSpec agent_spec; + agent_spec.agent_address = request->address(); + GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); + Status status(FAILED); + status = servable_->UnregisterAgent(agent_spec); + if (status != SUCCESS) { + MSI_LOG(ERROR) << "Agent Exit FAILED"; + } + } + if (Worker::GetInstance().IsRunning()) { + Worker::GetInstance().StopServable(); + } + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, + proto::AgentFailedReply *reply) { + if (Worker::GetInstance().IsRunning()) { + MSI_LOG_ERROR << "Expect worker should not be running"; + Worker::GetInstance().StopServable(); + } else { + servable_->OnAgentFailed(); + } + return grpc::Status::OK; +} +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h 
b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
new file mode 100644
index 0000000..147e7c5
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
+#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
+
+#include
+#include
+#include
+#include
+#include "common/serving_common.h"
+#include "proto/ms_service.pb.h"
+#include "proto/ms_service.grpc.pb.h"
+#include "proto/ms_distributed.pb.h"
+#include "proto/ms_distributed.grpc.pb.h"
+#include "worker/distributed_worker/distributed_servable.h"
+#include "worker/grpc/worker_process.h"
+
+namespace mindspore {
+namespace serving {
+
+// Worker-side service implementation for a distributed servable: extends MSWorkerImpl
+// with the agent lifecycle RPCs (register / exit / failed), all of which act on the
+// servable passed to the constructor.
+class MSDistributedImpl final : public MSWorkerImpl {
+ public:
+  // Keeps a shared reference to the servable that the agent RPCs operate on.
+  explicit MSDistributedImpl(std::shared_ptr servable) : servable_(servable) {}
+  ~MSDistributedImpl() = default;
+  // Registers each agent spec of the request with the servable.
+  grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
+                             proto::AgentRegisterReply *reply) override;
+  // Unregisters each agent spec of the request and stops the servable if the worker is running.
+  grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
+                         proto::AgentExitReply *reply) override;
+  // Handles an agent's failure report (see distributed_process.cc).
+  grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
+                           proto::AgentFailedReply *reply) override;
+
+ private:
+  std::shared_ptr servable_;  // target of RegisterAgent / UnregisterAgent / OnAgentFailed
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc
new file mode 100644
index 0000000..d9de7cd
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "worker/distributed_worker/distributed_process/distributed_server.h"
+#include
+#include
+#include
+#include "common/grpc_server.h"
+
+namespace mindspore {
+namespace serving {
+
+// Creates the service implementation and the async gRPC server for hostname:port, then
+// runs Init(). Fails if the server is already marked as running.
+Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
+  if (in_running_) {
+    return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
+  }
+  auto impl = std::make_unique(servable_);
+  // The server is handed a raw pointer taken before `impl` is moved; ownership then goes to
+  // service_impl_, so the pointee stays alive for as long as this object keeps service_impl_.
+  async_server_ = std::make_unique(hostname, port, impl.get());
+  service_impl_ = std::move(impl);
+  return Init();
+}
+
+}  // namespace serving
+}  // namespace mindspore
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h
new file mode 100644
index 0000000..ca6b967
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h
@@ -0,0 +1,178 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
+#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
+
+#include
+#include
+#include
+#include
+#include
+#include "common/serving_common.h"
+#include "proto/ms_worker.pb.h"
+#include "proto/ms_worker.grpc.pb.h"
+#include "common/grpc_async_server.h"
+#include "worker/grpc/worker_process.h"
+#include "worker/grpc/worker_server.h"
+#include "worker/distributed_worker/distributed_process/distributed_process.h"
+
+namespace mindspore {
+namespace serving {
+
+// Worker gRPC server for a distributed servable. StartWorkerGrpcServer (distributed_server.cc)
+// builds the service implementation and the async server around `servable_`.
+class MS_API MSDistributedWorkerServer : public MSWorkerServer {
+ public:
+  explicit MSDistributedWorkerServer(std::shared_ptr servable) : servable_(servable) {}
+  ~MSDistributedWorkerServer() = default;
+  Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override;
+
+ private:
+  std::shared_ptr servable_;
+};
+
+// Base class of the per-call contexts of the agent RPCs: a WorkerServiceContext that also
+// keeps the service pointer typed as MSDistributedImpl.
+class DistributedServiceContext : public WorkerServiceContext {
+ public:
+  DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                            grpc::ServerCompletionQueue *cq)
+      : WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {}
+
+ protected:
+  MSDistributedImpl *dist_service_impl_ = nullptr;  // non-owning
+};
+
+// Per-call context of the AgentRegister RPC.
+class WorkerAgentRegisterContext : public DistributedServiceContext {
+ public:
+  WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                             grpc::ServerCompletionQueue *cq)
+      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
+
+  ~WorkerAgentRegisterContext() = default;
+
+  // Heap-allocates a context and posts it to the completion queue to accept one call.
+  // NOTE(review): the matching delete is not in this file — presumably done by whoever
+  // drains the completion queue once state_ is FINISH; confirm.
+  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                               grpc::ServerCompletionQueue *cq) {
+    auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq);
+    call->StartEnqueueRequest();
+    return SUCCESS;
+  }
+
+  // Requests the next AgentRegister call; `this` is the completion-queue tag.
+  void StartEnqueueRequest() override {
+    state_ = STATE::PROCESS;
+    async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
+  }
+
+  // Handles the received call. A fresh context is enqueued first so the server keeps
+  // accepting AgentRegister calls while this one is being processed.
+  void HandleRequest() override {
+    EnqueueRequest(dist_service_impl_, async_service_, cq_);
+    state_ = STATE::FINISH;
+    grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_);
+    responder_.Finish(response_, status, this);
+  }
+
+ private:
+  grpc::ServerAsyncResponseWriter responder_;
+  proto::AgentRegisterRequest request_;
+  proto::AgentRegisterReply response_;
+};
+
+// Per-call context of the AgentExit RPC; same life cycle as WorkerAgentRegisterContext.
+class WorkerAgentExitContext : public DistributedServiceContext {
+ public:
+  WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                         grpc::ServerCompletionQueue *cq)
+      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
+
+  ~WorkerAgentExitContext() = default;
+
+  // Heap-allocates a context and posts it to the completion queue to accept one call.
+  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                               grpc::ServerCompletionQueue *cq) {
+    auto call = new WorkerAgentExitContext(service_impl, async_service, cq);
+    call->StartEnqueueRequest();
+    return SUCCESS;
+  }
+
+  // Requests the next AgentExit call; `this` is the completion-queue tag.
+  void StartEnqueueRequest() override {
+    state_ = STATE::PROCESS;
+    async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this);
+  }
+
+  // Re-arms the RPC with a fresh context, then processes this call and sends the reply.
+  void HandleRequest() override {
+    EnqueueRequest(dist_service_impl_, async_service_, cq_);
+    state_ = STATE::FINISH;
+    grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_);
+    responder_.Finish(response_, status, this);
+  }
+
+ private:
+  grpc::ServerAsyncResponseWriter responder_;
+  proto::AgentExitRequest request_;
+  proto::AgentExitReply response_;
+};
+
+// Per-call context of the AgentFailed RPC; same life cycle as WorkerAgentRegisterContext.
+class WorkerAgentFailedContext : public DistributedServiceContext {
+ public:
+  WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                           grpc::ServerCompletionQueue *cq)
+      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
+
+  ~WorkerAgentFailedContext() = default;
+  // Heap-allocates a context and posts it to the completion queue to accept one call.
+  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
+                               grpc::ServerCompletionQueue *cq) {
+    auto call = new WorkerAgentFailedContext(service_impl, async_service, cq);
+    call->StartEnqueueRequest();
+    return SUCCESS;
+  }
+
+  // Requests the next AgentFailed call; `this` is the completion-queue tag.
+  void StartEnqueueRequest() override {
+    state_ = STATE::PROCESS;
+    async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
+  }
+
+  // Re-arms the RPC with a fresh context, then processes this call and sends the reply.
+  void HandleRequest() override {
+    EnqueueRequest(dist_service_impl_, async_service_, cq_);
+    state_ = STATE::FINISH;
+    grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_);
+    responder_.Finish(response_, status, this);
+  }
+
+ private:
+  grpc::ServerAsyncResponseWriter responder_;
+  proto::AgentFailedRequest request_;
+  proto::AgentFailedReply response_;
+};
+
+// Async gRPC server that serves the three agent RPCs in addition to whatever
+// WorkerGrpcServer already serves.
+class DistributedWorkerGrpcServer : public WorkerGrpcServer {
+ public:
+  DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl)
+      : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {}
+
+  ~DistributedWorkerGrpcServer() = default;
+
+  // Arms the base-class RPCs and one pending context for each agent RPC.
+  // The Status values of the individual calls are not checked.
+  // NOTE(review): no `override` here — confirm that the base EnqueueRequest is virtual,
+  // otherwise calls through a base pointer will not reach this function.
+  Status EnqueueRequest() {
+    WorkerGrpcServer::EnqueueRequest();
+    WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
+    WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
+    WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
+    return SUCCESS;
+  }
+
+ private:
+  MSDistributedImpl *distributed_service_impl_;  // non-owning
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc
new file mode 100644
index 0000000..ea83d1c
--- /dev/null
+++
b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -0,0 +1,335 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/distributed_servable.h" +#include +#include +#include +#include "worker/distributed_worker/notify_agent/notify_agent.h" +#include "common/exit_handle.h" + +namespace mindspore { +namespace serving { + +DistributedServable::~DistributedServable() { Clear(); } + +std::string DistributedServable::GetServableName() const { return servable_name_; } + +uint64_t DistributedServable::GetServableVersion() const { return version_number_; } + +Status DistributedServable::Predict(const std::vector &input, std::vector *output) { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return Status(); +} +std::vector DistributedServable::GetInputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return input_infos_; +} + +std::vector DistributedServable::GetOutputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return output_infos_; +} + +uint64_t DistributedServable::GetBatchSize() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return batch_size_; +} + +Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) const { + *config = config_; + 
return SUCCESS; +} + +void DistributedServable::SetWaitAgentsPromise(bool flag) { + if (!promise_set_flag_.test_and_set()) { + agents_promise_.set_value(flag); + } +} + +Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; + + if (agent_spec.rank_id < config_.distributed_meta.rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; + } + DistributedAgentContext context; + auto it = agent_spec_map_.find(agent_spec.rank_id); + if (it != agent_spec_map_.end()) { + MSI_LOG_WARNING << "rank_id " << agent_spec.rank_id << " has been registered"; + return SUCCESS; + } + context.agent_spec_ = agent_spec; + std::shared_ptr notify_agent = std::make_shared(agent_spec.agent_address); + context.notify_agent_ = notify_agent; + agent_spec_map_[agent_spec.rank_id] = context; + + if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { + SetWaitAgentsPromise(true); + } + return SUCCESS; +} + +void DistributedServable::Clear() { + std::unique_lock lock{mutex_}; + for (auto &agent : agent_spec_map_) { + agent.second.notify_agent_->Exit(); + } + agent_spec_map_.clear(); + MSI_LOG_INFO << "End Clear servable"; +} + +Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; + for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { + if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { + iter = agent_spec_map_.erase(iter); + } else { + ++iter; + } + } + SetWaitAgentsPromise(false); + return SUCCESS; +} + +Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint64_t version_number, + uint64_t wait_agents_time_in_seconds) { + if (model_loaded_) { + MSI_LOG_EXCEPTION << "Model has loaded"; + } + version_number_ = version_number; + 
servable_name_ = servable_name; + rank_table_json_file_ = rank_table_json_file; + ServableSignature signature; + if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; + } + auto &meta = signature.servable_meta; + if (meta.servable_type != kServableTypeDistributed) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable '" << servable_name << "' is not registered as distributed servable, " << meta.Repr(); + } + config_.common_meta = meta.common_meta; + config_.distributed_meta = meta.distributed_meta; + + auto status = InitConfigOnStartup(rank_table_json_file_); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Init with rank table on start up failed"; + return status; + } + status = CheckRankConfig(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Check rank config failed"; + return status; + } + status = WaitAgentsReady(wait_agents_time_in_seconds); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Waiting for ready of agents failed"; + return status; + } + status = CheckAgentsInfosAndInitTensorInfos(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Check agents infos failed"; + return status; + } + model_loaded_ = true; + return SUCCESS; +} + +Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } + +Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { + auto future = agents_promise_.get_future(); + if (wait_agents_time_in_seconds == 0) { + wait_agents_time_in_seconds = UINT32_MAX; + } + const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10; + uint64_t i; + for (i = 0; i < kWaitMaxHundredMs; i++) { // + if (ExitSignalHandle::Instance().HasStopped()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; + } + // waiting for 100ms + if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { + auto flag = 
future.get(); + if (!flag) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported"; + } + break; + } + } + if (i >= kWaitMaxHundredMs) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size() + << ", rank size: " << config_.distributed_meta.rank_size; + } + return SUCCESS; +} + +Status DistributedServable::CompareTensorInfos(const std::vector &lefts, + const std::vector &rights) { + if (lefts.size() != rights.size()) { + return INFER_STATUS(FAILED) << "Size not match, left: " << lefts.size() << ", right: " << rights.size(); + } + auto tensor_info_as_str = [](const TensorInfo &tensor_info) { + Status status = INFER_STATUS(SUCCESS) << "size: " << tensor_info.size << ", data type: " << tensor_info.data_type + << ", shape: " << tensor_info.shape; + return status.StatusMessage(); + }; + for (size_t k = 0; k < lefts.size(); k++) { + auto &left = lefts[k]; + auto &right = rights[k]; + if (left.size != right.size || left.shape != right.shape || left.data_type != right.data_type) { + return INFER_STATUS(FAILED) << "Index " << k << " tensor not match, left- " << tensor_info_as_str(left) + << "; right- " << tensor_info_as_str(right); + } + } + return SUCCESS; +} + +Status DistributedServable::CheckAgentsInfosAndInitTensorInfos() { + auto rank_size = config_.distributed_meta.rank_size; + auto stage_size = config_.distributed_meta.stage_size; + auto parallel_count = rank_size / stage_size; + MSI_LOG_INFO << "Check agents infos, rank size :" << rank_size << ", stage size: " << stage_size + << ", parallel count: " << parallel_count; + if (agent_spec_map_.size() != rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Registered agents size " << agent_spec_map_.size() << " not match rank size " << rank_size; + } + + input_infos_ = agent_spec_map_[0].agent_spec_.input_infos; + output_infos_ = agent_spec_map_[rank_size - 1].agent_spec_.output_infos; + 
batch_size_ = agent_spec_map_[0].agent_spec_.batch_size; + if (input_infos_.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << 0 << " input count cannot be 0"; + } + if (output_infos_.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size - 1 << " output count cannot be 0"; + } + Status status; + for (size_t i = 0; i < parallel_count; i++) { + auto &agent_spec = agent_spec_map_[i]; + status = CompareTensorInfos(agent_spec.agent_spec_.input_infos, input_infos_); + if (status != SUCCESS) { + status = INFER_STATUS_LOG_ERROR(FAILED) + << "Rank " << i << " input infos not match rank 0, details: " << status.StatusMessage(); + return status; + } + } + for (size_t i = parallel_count; i < rank_size; i++) { + auto &agent_spec = agent_spec_map_[i]; + if (!agent_spec.agent_spec_.input_infos.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Expect rank " << i << " input count equal to 0"; + } + } + for (size_t i = 0; i < rank_size; i++) { + auto &first_item = agent_spec_map_[i]; + for (size_t k = 0; k < parallel_count && i + k < rank_size; k++) { + auto rank_id = i + k; + auto &agent_spec = agent_spec_map_[i + k]; + status = CompareTensorInfos(agent_spec.agent_spec_.output_infos, first_item.agent_spec_.output_infos); + if (status != SUCCESS) { + status = INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size << " output infos not match rank " << i + << ", details: " << status.StatusMessage(); + return status; + } + if (agent_spec.agent_spec_.batch_size != 0 && agent_spec.agent_spec_.batch_size != batch_size_) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Expect rank " << rank_id << " batch size equal to 0 or rank 0 batch size " << batch_size_; + } + } + } + return SUCCESS; +} + +Status DistributedServable::CheckRankConfig() { + auto rank_size = config_.distributed_meta.rank_size; + auto stage_size = config_.distributed_meta.stage_size; + if (stage_size == 0 || rank_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size or 
stage size cannot be 0, rank size: " << rank_size << ", stage size: " << stage_size; + } + if (rank_size % stage_size != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size must be an integral multiple of stage size, rank size: " << rank_size + << ", stage size: " << stage_size; + } + if (config_.rank_list.size() != rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " + << rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_; + } + auto parallel_count = rank_size / stage_size; + constexpr size_t card_count_per_machine = 8; + if (stage_size == 1) { + std::map> device_map; + for (size_t i = 0; i < rank_size; i++) { + const auto &item = config_.rank_list[i]; + auto &device_id_list = device_map[item.ip]; + if (device_id_list.count(item.device_id) > 0) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " + << i << " in device ip " << item.ip; + } + device_id_list.emplace(item.device_id); + } + } else { + if (rank_size < card_count_per_machine) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size " << rank_size << "must >= card count " << card_count_per_machine + << " of one machine when stage size " << stage_size << " > 1"; + } + if (parallel_count % card_count_per_machine != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine + << "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size; + } + for (size_t i = 0; i < rank_size; i += card_count_per_machine) { + const auto &first_item = config_.rank_list[i]; + for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { + auto rank_id = i + k; + const auto &item = config_.rank_list[rank_id]; + if (k != item.device_id) { + return INFER_STATUS_LOG_ERROR(FAILED) + << 
"Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + } + if (first_item.ip != item.ip) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id + << " to be equal with device ip " << first_item.ip << " of rank " << i; + } + } + } + } + MSI_LOG_INFO << "Check rank table success, rank size: " << rank_size << ", stage size: " << stage_size + << ", parallel count in one stage: " << parallel_count; + return SUCCESS; +} + +void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); } + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h new file mode 100644 index 0000000..d810209 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H
+#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H
+
+#include
+#include
+#include
+#include
+#include "worker/sevable_base.h"
+#include "worker/distributed_worker/common.h"
+#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
+
+namespace mindspore {
+namespace serving {
+
+// Everything the worker keeps about one registered agent: the spec it reported and the
+// notifier used to send it messages (for example Exit).
+struct DistributedAgentContext {
+  WorkerAgentSpec agent_spec_;
+  std::shared_ptr notify_agent_ = nullptr;
+};
+
+// Servable whose model is split across several worker agents, one per rank.
+// Agents register themselves over gRPC; StartServable blocks until all of them have
+// registered, one reports a failure, the process is asked to exit, or the timeout expires.
+class MS_API DistributedServable : public ServableBase {
+ public:
+  DistributedServable() = default;
+  // Calls Clear().
+  ~DistributedServable();
+  // Called from Python (worker.py). Validates the rank table, waits for the agents and
+  // initializes the tensor infos. A wait time of 0 means "practically unbounded".
+  Status StartServable(const std::string &servable_directory, const std::string &servable_name,
+                       const std::string &rank_table_json_file, uint64_t version_number,
+                       uint64_t wait_agents_time_in_seconds);
+
+  // Invoked on behalf of an agent: copies config_ into `*config`.
+  Status GetDistributedServableConfig(DistributedServableConfig *config) const;
+  // send model and group
+
+  // Register and unregister an agent in agent_spec_map_ (keyed by rank id).
+  Status RegisterAgent(const WorkerAgentSpec &agent_spec);
+  Status UnregisterAgent(const WorkerAgentSpec &agent_spec);
+
+  // Predict; intended to use config_ and agent_spec_map_.
+  // NOTE(review): the implementation in distributed_servable.cc is still a stub.
+  Status Predict(const std::vector &input, std::vector *output) override;
+
+  // The getters below raise (MSI_LOG_EXCEPTION) until StartServable has succeeded.
+  std::vector GetInputInfos() const override;
+  std::vector GetOutputInfos() const override;
+  uint64_t GetBatchSize() const override;
+  std::string GetServableName() const override;
+  uint64_t GetServableVersion() const override;
+  // Sends Exit to every registered agent and empties agent_spec_map_.
+  void Clear() override;
+  // Releases a pending WaitAgentsReady with failure.
+  void OnAgentFailed();
+
+ private:
+  DistributedServableConfig config_;
+  std::string servable_name_;
+  uint64_t version_number_ = 0;
+  bool model_loaded_ = false;  // set once StartServable has completed successfully
+
+  std::mutex mutex_;  // locked by RegisterAgent, UnregisterAgent and Clear
+  std::map agent_spec_map_;  // rank id -> agent context
+  std::string rank_table_json_file_;
+
+  std::vector input_infos_;
+  std::vector output_infos_;
+  uint64_t batch_size_ = 0;
+  // Guards agents_promise_ so that its value is set at most once.
+  std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT;
+  // true: all agents registered; false: an agent failed or unregistered while waiting.
+  std::promise agents_promise_;
+
+  Status InitConfigOnStartup(const std::string &rank_table_json_file);
+  Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds);
+  Status CheckAgentsInfosAndInitTensorInfos();
+  Status CompareTensorInfos(const std::vector &lefts, const std::vector &rights);
+  Status CheckRankConfig();
+  void SetWaitAgentsPromise(bool flag);
+  // agent stubs
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h
new file mode 100644
index 0000000..ac4d5c7
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
+#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
+#include
+#include
+#include
+#include "common/serving_common.h"
+#include "common/servable.h"
+#include "proto/ms_agent.pb.h"
+#include "common/grpc_client.h"
+
+namespace mindspore {
+namespace serving {
+
+// Abstract interface for talking to one agent. Concrete notifiers implement both calls;
+// the virtual destructor allows deletion through a base pointer.
+class MS_API BaseNotifyAgent {
+ public:
+  BaseNotifyAgent() = default;
+  virtual ~BaseNotifyAgent() = default;
+  // Asks the agent to exit.
+  virtual Status Exit() = 0;
+  // Starts an asynchronous predict on the agent; `callback` is handed to the implementation.
+  // NOTE(review): `reply` presumably must stay valid until the callback runs; confirm.
+  virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
+                               DispatchCallback callback) = 0;
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc
new file mode 100644
index 0000000..3220a6c
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "worker/distributed_worker/notify_agent/notify_agent.h"
+#include
+#include
+#include
+#include
+#include "common/exit_handle.h"
+#include "common/grpc_server.h"
+#include "common/grpc_client.h"
+
+namespace mindspore {
+namespace serving {
+
+// Opens a channel to the agent at `agent_address` and builds the MSAgent stub on it.
+GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &agent_address) {
+  agent_address_ = agent_address;
+  auto channel = GrpcServer::CreateChannel(agent_address_);
+  stub_ = proto::MSAgent::NewStub(channel);
+}
+
+GrpcNotfiyAgent::~GrpcNotfiyAgent() = default;
+
+// Sends a best-effort Exit request to the agent with a one second deadline.
+// Always returns SUCCESS; a failed RPC is only logged.
+Status GrpcNotfiyAgent::Exit() {
+  if (stub_) {
+    proto::DistributedExitRequest request;
+    request.set_address(agent_address_);
+    proto::DistributedExitReply reply;
+    grpc::ClientContext context;
+    const int32_t TIME_OUT = 1;
+    std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
+    context.set_deadline(deadline);
+
+    // The agent may already be gone, so a failure here is not an error for the caller.
+    grpc::Status status = stub_->Exit(&context, request, &reply);
+    if (!status.ok()) {
+      MSI_LOG_INFO << "Notify agent " << agent_address_ << " exit failed, grpc message: " << status.error_code()
+                   << ", " << status.error_message();
+    }
+  }
+  return SUCCESS;
+}
+
+// Forwards a predict request to the agent through the shared distributed_client_,
+// creating and starting that client on first use. Returns FAILED when there is no stub.
+Status GrpcNotfiyAgent::DispatchAsync(const proto::DistributedPredictRequest &request,
+                                      proto::DistributedPredictReply *reply, DispatchCallback callback) {
+  if (!stub_) {
+    return INFER_STATUS_LOG_ERROR(FAILED)
+           << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_;
+  }
+  // NOTE(review): this check-then-create takes no lock; confirm callers serialize the first call.
+  if (!distributed_client_) {
+    distributed_client_ = std::make_unique();
+    distributed_client_->Start();
+  }
+  distributed_client_->PredictAsync(request, reply, stub_.get(), callback);
+  return SUCCESS;
+}
+
+}  // namespace serving
+}  // namespace mindspore
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h
new file mode 100644
index 0000000..53fd39f
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed
under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
+#define MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
+#include
+#include
+#include
+#include
+#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
+#include "proto/ms_agent.pb.h"
+#include "proto/ms_agent.grpc.pb.h"
+
+namespace mindspore {
+namespace serving {
+
+// gRPC-backed BaseNotifyAgent: holds the address of one agent and a stub for calling it.
+class MS_API GrpcNotfiyAgent : public BaseNotifyAgent {
+ public:
+  // `agent_address` is the address of the agent to connect to; it is stored in agent_address_.
+  explicit GrpcNotfiyAgent(const std::string &agent_address);
+  ~GrpcNotfiyAgent() override;
+
+  // Asks the agent to exit.
+  Status Exit() override;
+
+  // Starts an asynchronous predict on the agent.
+  Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
+                       DispatchCallback callback) override;
+
+ private:
+  std::string agent_address_;
+  std::shared_ptr stub_ = nullptr;
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
new file mode 100644
index 0000000..379eeff
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "worker/distributed_worker/notify_distributed/notify_worker.h"
+#include
+#include
+#include
+#include
+#include "common/exit_handle.h"
+#include "common/grpc_server.h"
+#include "common/proto_tensor.h"
+
+namespace mindspore {
+namespace serving {
+
+// Stores both endpoints, derives the "ip:port" address strings and creates the MSWorker stub
+// on a channel to the distributed worker.
+GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip,
+                                                       uint32_t distributed_worker_port, const std::string &host_ip,
+                                                       uint32_t host_port)
+    : distributed_worker_ip_(distributed_worker_ip),
+      distributed_worker_port_(distributed_worker_port),
+      host_ip_(host_ip),
+      host_port_(host_port) {
+  distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port);
+  agent_address_ = host_ip_ + ":" + std::to_string(host_port_);
+  auto channel = GrpcServer::CreateChannel(distributed_worker_address_);
+  stub_ = proto::MSWorker::NewStub(channel);
+}
+
+GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default;
+
+// Registers `worker_specs` with the distributed worker, retrying up to REGISTER_TIME_OUT times.
+// Each attempt has a REGISTER_INTERVAL second RPC deadline and is followed by a sleep of the
+// same length. Returns SUCCESS on the first successful RPC, FAILED if the exit signal stopped
+// the loop, SYSTEM_ERROR when all attempts are used up.
+Status GrpcNotifyDistributeWorker::Register(const std::vector &worker_specs) {
+  const int32_t REGISTER_TIME_OUT = 60;
+  const int32_t REGISTER_INTERVAL = 1;
+  auto loop = REGISTER_TIME_OUT;
+  while (loop-- && !ExitSignalHandle::Instance().HasStopped()) {
+    MSI_LOG(INFO) << "Register to " << distributed_worker_address_;
+    proto::AgentRegisterRequest request;
+    GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request);
+    proto::AgentRegisterReply reply;
+    // A ClientContext is single-use, so a fresh one is built for every attempt.
+    grpc::ClientContext context;
+    std::chrono::system_clock::time_point deadline =
+      std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL);
+    context.set_deadline(deadline);
+    grpc::Status status = stub_->AgentRegister(&context, request, &reply);
+    if (status.ok()) {
+      MSI_LOG(INFO) << "Register SUCCESS ";
+      return SUCCESS;
+    }
+    MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message();
+    std::this_thread::sleep_for(std::chrono::seconds(REGISTER_INTERVAL));
+  }
+  if (ExitSignalHandle::Instance().HasStopped()) {
+    return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration";
+  }
+  return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut";
+}
+
+// Tells the distributed worker that this agent is exiting. Only the first call sends the RPC;
+// later calls return SUCCESS immediately.
+Status GrpcNotifyDistributeWorker::Unregister() {
+  // exchange() tests and sets the flag in one atomic step, so concurrent callers cannot both
+  // pass the check and send AgentExit twice.
+  if (is_stoped_.exchange(true)) {
+    return SUCCESS;
+  }
+  proto::AgentExitRequest request;
+  request.set_address(agent_address_);
+  proto::AgentExitReply reply;
+  grpc::ClientContext context;
+  const int32_t TIME_OUT = 1;
+  std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
+  context.set_deadline(deadline);
+  grpc::Status status = stub_->AgentExit(&context, request, &reply);
+  if (status.ok()) {
+    MSI_LOG(INFO) << "Exit SUCCESS ";
+    return SUCCESS;
+  }
+  return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed";
+}
+
+// Static helper: opens a one-off channel to the worker at worker_ip:worker_port and reports
+// an agent failure with an empty AgentFailedRequest.
+Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
+  auto address = worker_ip + ":" + std::to_string(worker_port);
+  auto channel = GrpcServer::CreateChannel(address);
+  auto stub = proto::MSWorker::NewStub(channel);
+
+  // NOTE(review): unlike Register/Unregister no deadline is set on this context; confirm
+  // whether an unbounded wait is intended here.
+  grpc::ClientContext context;
+  proto::AgentFailedRequest request;
+  proto::AgentFailedReply reply;
+  grpc::Status status = stub->AgentFailed(&context, request, &reply);
+  if (status.ok()) {
+    MSI_LOG(INFO) << "Success to notify failure of agent";
+    return SUCCESS;
+  }
+  return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent";
+}
+
+}  // namespace serving
+}  // namespace mindspore
diff --git
a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
new file mode 100644
index 0000000..da509ff
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H
+#define MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H
+#include
+#include
+#include
+#include "common/serving_common.h"
+#include "worker/distributed_worker/common.h"
+#include "proto/ms_distributed.pb.h"
+#include "proto/ms_distributed.grpc.pb.h"
+#include "proto/ms_worker.pb.h"
+#include "proto/ms_worker.grpc.pb.h"
+namespace mindspore {
+namespace serving {
+
+// Agent-side client used to register with, and unregister from, the distributed worker.
+class MS_API GrpcNotifyDistributeWorker {
+ public:
+  // (worker_ip, worker_port) is the distributed worker to call;
+  // (agent_ip, agent_port) is this agent's own endpoint.
+  GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip,
+                             uint32_t agent_port);
+  ~GrpcNotifyDistributeWorker();
+  // Registers the given agent specs with the worker.
+  Status Register(const std::vector &agent_specs);
+  // Notifies the worker that this agent exits.
+  Status Unregister();
+  // Called from start-up code rather than from a running agent; needs no instance.
+  static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);
+
+ private:
+  std::string distributed_worker_ip_;
+  uint32_t distributed_worker_port_ = 0;
+  std::string host_ip_;
+  uint32_t host_port_ = 0;
+  std::string agent_address_;               // "host_ip:host_port"
+  std::string distributed_worker_address_;  // "worker_ip:worker_port"
+  std::unique_ptr stub_;
+  // Set once Unregister() has been entered.
+  std::atomic is_stoped_{false};
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc
new file mode 100644
index 0000000..a819b95
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "worker/distributed_worker/worker_agent.h"
+#include
+#include "worker/distributed_worker/agent_process/agent_process.h"
+#include "worker/distributed_worker/notify_distributed/notify_worker.h"
+#include "common/exit_handle.h"
+
+namespace mindspore {
+namespace serving {
+
+// Returns the process-wide agent, created on first use.
+WorkerAgent &WorkerAgent::Instance() {
+  static WorkerAgent instance;
+  return instance;
+}
+
+// Tears the agent down: unregisters from the worker (unless StopAgent(false) disabled that),
+// stops the gRPC server and unloads the model. Always returns SUCCESS.
+Status WorkerAgent::Clear() {
+  if (notify_worker_) {
+    if (exit_notify_worker_) {
+      notify_worker_->Unregister();
+    }
+    notify_worker_ = nullptr;
+  }
+  grpc_server_.Stop();
+  executor_.UnloadModel();
+  return SUCCESS;
+}
+
+// Placeholder: prediction is not wired up yet, the request is ignored and SUCCESS returned.
+Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) {
+  // todo : DistributedPredictRequest->RequestBase
+  // todo : DistributedPredictReply->ReplyBase
+  return SUCCESS;
+}
+
+// Starts the agent in three steps: load the model, start the agent gRPC server, register with
+// the distributed worker. Stops at, logs and returns the first failing step's status.
+Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) {
+  Status status;
+  config_ = config;
+  status = executor_.LoadModelFromFile(config);
+  if (status != SUCCESS) {
+    MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name
+                  << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
+                  << ", model file: " << config.model_file_name
+                  << ", rank table file: " << config.rank_table_json_file_name
+                  << ", group config file: " << config.group_file_name;
+    return status;
+  }
+  status = StartGrpcServer();
+  if (status != SUCCESS) {
+    MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip
+                  << ", agent port: " << config.agent_port;
+    return status;
+  }
+  status = RegisterAgent();
+  if (status != SUCCESS) {
+    MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port
+                  << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port;
+    return status;
+  }
+  MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name
+               << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
+               << ", model file: " << config.model_file_name
+               << ", rank table file: " << config.rank_table_json_file_name
+               << ", group config file: " << config.group_file_name;
+  return SUCCESS;
+}
+
+// Starts the agent's gRPC server on config_.agent_ip:config_.agent_port.
+// NOTE(review): the result of grpc_server_.Start is not inspected, so this always reports
+// SUCCESS; confirm whether Start returns a Status that should be propagated.
+Status WorkerAgent::StartGrpcServer() {
+  grpc_server_.Start(std::make_shared(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent");
+  return SUCCESS;
+}
+
+// Builds the notifier towards the distributed worker and registers this agent's spec
+// (address, rank id, batch size and tensor infos taken from the executor).
+Status WorkerAgent::RegisterAgent() {
+  // The first endpoint is the worker's address, so it must use worker_port, not agent_port.
+  notify_worker_ = std::make_shared(config_.worker_ip, config_.worker_port, config_.agent_ip,
+                                    config_.agent_port);
+  WorkerAgentSpec spec;
+  spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port);
+  spec.rank_id = config_.rank_id;
+  spec.batch_size = executor_.GetBatchSize();
+  spec.input_infos = executor_.GetInputInfos();
+  spec.output_infos = executor_.GetOutputInfos();
+  return notify_worker_->Register({spec});
+}
+
+// Records whether Clear() should notify the worker, then triggers the exit signal handler.
+void WorkerAgent::StopAgent(bool notify_worker) {
+  exit_notify_worker_ = notify_worker;
+  ExitSignalHandle::Instance().Stop();
+}
+
+}  // namespace serving
+}  // namespace mindspore
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h
new file mode 100644
index 0000000..702e791
--- /dev/null
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SERVING_WORKER_AGENT_H
+#define MINDSPORE_SERVING_WORKER_AGENT_H
+#include
+#include
+#include "worker/distributed_worker/agent_executor.h"
+#include "proto/ms_agent.pb.h"
+#include "proto/ms_agent.grpc.pb.h"
+#include "common/grpc_server.h"
+#include "worker/distributed_worker/common.h"
+#include "worker/distributed_worker/notify_distributed/notify_worker.h"
+
+namespace mindspore {
+namespace serving {
+// Agent process state: the model executor, the agent's gRPC server and the notifier used to
+// register with the distributed worker. Accessed through Instance().
+class MS_API WorkerAgent {
+ public:
+  static WorkerAgent &Instance();
+  // Unregisters from the worker (if enabled), stops the gRPC server and unloads the model.
+  Status Clear();
+
+  Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply);
+
+  // Loads the model, starts the gRPC server and registers with the worker, in that order.
+  Status StartAgent(const AgentStartUpConfig &config);
+
+  // Triggers the exit signal handler; `notify_worker` decides whether Clear() unregisters.
+  void StopAgent(bool notify_worker = true);
+
+ private:
+  AgentStartUpConfig config_;
+  WorkerAgentExecutor executor_;
+  GrpcServer grpc_server_;
+  bool exit_notify_worker_ = true;
+  std::shared_ptr notify_worker_;
+
+  Status StartGrpcServer();
+  Status RegisterAgent();
+};
+
+}  // namespace serving
+}  // namespace mindspore
+
+#endif  // MINDSPORE_SERVING_WORKER_AGENT_H
diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc
index 73c38d1..2d41b03 100644
--- a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc
+++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc
@@ -15,7 +15,6 @@ */ #include "worker/grpc/worker_process.h" -#include "master/dispacther.h" #include "worker/worker.h" namespace mindspore {
diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.h b/mindspore_serving/ccsrc/worker/grpc/worker_process.h
index 450158e..ebdb3c5 100644
--- a/mindspore_serving/ccsrc/worker/grpc/worker_process.h
+++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.h
@@ -28,7 +28,7 @@ namespace mindspore { namespace serving { // Service Implement -class MSWorkerImpl final : public proto::MSWorker::Service { +class MSWorkerImpl : public proto::MSWorker::Service { public: grpc::Status Predict(grpc::ServerContext
*context, const proto::PredictRequest *request, proto::PredictReply *reply) override; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc index cc603ad..58880df 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc @@ -21,12 +21,20 @@ namespace mindspore { namespace serving { + MSWorkerServer::~MSWorkerServer() { Stop(); } -MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { +Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { + if (in_running_) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; + } service_impl_ = std::make_unique(); async_server_ = std::make_unique(hostname, port, service_impl_.get()); + return Init(); } + +MSWorkerServer::MSWorkerServer() = default; + Status MSWorkerServer::Init() { Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); if (status != SUCCESS) return status; @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { return status; } Status MSWorkerServer::Stop() { - if (in_running_) { + if (in_running_ && async_server_) { async_server_->Stop(); - grpc_thread_.join(); + if (grpc_thread_.joinable()) { + grpc_thread_.join(); + } } + async_server_ = nullptr; + service_impl_ = nullptr; in_running_ = false; return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 8bcc057..d02d014 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -27,40 +27,53 @@ #include "proto/ms_worker.grpc.pb.h" #include "common/grpc_async_server.h" #include "worker/grpc/worker_process.h" +#include "worker/distributed_worker/distributed_servable.h" namespace mindspore { namespace serving { // Service Implement -class MSWorkerServer { +class MS_API 
MSWorkerServer { public: enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; - MSWorkerServer(const std::string &hostname, int32_t port); - ~MSWorkerServer(); - - Status Init(); + MSWorkerServer(); + virtual ~MSWorkerServer(); + virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); Status Stop(); - Status StartAsyncRpcService(); - + protected: bool in_running_ = false; std::thread grpc_thread_; - std::unique_ptr service_impl_; - std::unique_ptr async_server_; + std::unique_ptr service_impl_ = nullptr; + std::unique_ptr async_server_ = nullptr; + + Status Init(); + Status StartAsyncRpcService(); }; class WorkerServiceContext { public: enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; + + WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq) { + state_ = STATE::CREATE; + } virtual ~WorkerServiceContext() {} + bool JudgeFinish() { return state_ == STATE::FINISH; } + virtual void StartEnqueueRequest() = 0; virtual void HandleRequest() = 0; - virtual bool JudgeFinish() = 0; + protected: + MSWorkerImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; - public: STATE state_; }; @@ -68,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext { public: WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerPredictContext() = default; @@ -93,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { 
return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::PredictRequest request_; proto::PredictReply response_; @@ -109,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext { public: WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerExitContext() = default; @@ -134,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::ExitRequest request_; proto::ExitReply response_; @@ -174,7 +171,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { return SUCCESS; } - private: + protected: MSWorkerImpl *service_impl_; proto::MSWorker::AsyncService svc_; }; diff --git a/mindspore_serving/ccsrc/worker/inference/inference.h b/mindspore_serving/ccsrc/worker/inference/inference.h index 72df989..337155c 100644 --- a/mindspore_serving/ccsrc/worker/inference/inference.h +++ b/mindspore_serving/ccsrc/worker/inference/inference.h @@ -52,132 +52,6 @@ enum DeviceType { kDeviceTypeCpu, }; -class MS_API InferSession { - public: - InferSession() = default; - virtual ~InferSession() = default; - virtual Status InitEnv(DeviceType device_type, uint32_t device_id, - const std::map &other_options) = 0; - virtual Status FinalizeEnv() = 0; - - virtual Status LoadModelFromFile(serving::DeviceType 
device_type, uint32_t device_id, const std::string &file_name, - ModelType model_type, const std::vector &without_batch_dim_inputs, - const std::map &other_options, uint32_t *model_id) = 0; - - virtual Status UnloadModel(uint32_t model_id) = 0; - // override this method to avoid request/reply data copy - virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) = 0; - virtual Status ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) { - VectorTensorPtrWrapRequest wrap_request(request); - VectorTensorPtrWrapReply wrap_reply(reply, []() { return std::make_shared(); }); - return ExecuteModel(model_id, wrap_request, &wrap_reply); - } - - virtual std::vector GetInputInfos(uint32_t model_id) const = 0; - virtual std::vector GetOutputInfos(uint32_t model_id) const = 0; - virtual ssize_t GetBatchSize(uint32_t model_id) const = 0; - virtual bool CheckModelSupport(DeviceType device_type, ModelType model_type) const { return true; } -}; - -struct InferSessionRegInfo { - std::shared_ptr session; - ModelType model_type; - int priority; -}; - -class MS_API InferSessionStorage { - public: - void Register(DeviceType device_type, ModelType model_type, const std::shared_ptr &session, - int priority) { - auto &list = session_map_[device_type]; - InferSessionRegInfo info{session, model_type, priority}; - list.push_back(info); - } - - std::shared_ptr Get(DeviceType device_type, ModelType model_type, DeviceType *specified_device_type) { - MSI_EXCEPTION_IF_NULL(specified_device_type); - if (device_type == kDeviceTypeNotSpecified) { - for (auto &item_device : session_map_) { - std::shared_ptr ret_session = GetSession(item_device.second, item_device.first, model_type); - if (ret_session) { - *specified_device_type = item_device.first; - return ret_session; - } - } - return nullptr; - } else if (device_type == kDeviceTypeAscend) { - auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; - for (auto ascend_type : 
ascend_list) { - auto it = session_map_.find(ascend_type); - if (it == session_map_.end()) { - continue; - } - auto session_ret = GetSession(it->second, ascend_type, model_type); - if (session_ret != nullptr) { - *specified_device_type = ascend_type; - return session_ret; - } - } - return nullptr; - } - auto it = session_map_.find(device_type); - if (it == session_map_.end()) { - return nullptr; - } - std::shared_ptr session_ret; - session_ret = GetSession(it->second, device_type, model_type); - *specified_device_type = device_type; - return session_ret; - } - - static InferSessionStorage &Instance() { - static InferSessionStorage instance; - return instance; - } - - private: - std::unordered_map> session_map_; - - std::shared_ptr GetSession(const std::vector &session_list, DeviceType device_type, - ModelType model_type) { - std::shared_ptr session_ret = nullptr; - int cur_priority = INT32_MIN; - for (auto &item : session_list) { - if (item.model_type != model_type) { - continue; - } - if (session_ret == nullptr || cur_priority < item.priority) { - if (!item.session->CheckModelSupport(device_type, model_type)) { - MSI_LOG_INFO << "CheckModelSupport for " << device_type << " " << model_type << " failed, skipped"; - continue; - } - cur_priority = item.priority; - session_ret = item.session; - } - } - return session_ret; - } -}; - -class MS_API InferSessionRegister { - public: - InferSessionRegister(DeviceType device_type, ModelType model_type, const std::shared_ptr &session, - int priority) { - InferSessionStorage::Instance().Register(device_type, model_type, session, priority); - } -}; - -#define REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) \ - static mindspore::serving::InferSessionRegister g_register_session_##cls_name##_##index( \ - device_type, model_type, std::make_shared(), priority); - -#define REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, index) \ - REGISTER_INFER_SEESION_UNIQUE(device_type, 
model_type, cls_name, priority, index) - -#define REGISTER_INFER_SEESION(device_type, model_type, cls_name, priority) \ - REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, __COUNTER__); - static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) { switch (device_type) { case kDeviceTypeAscend: diff --git a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc index 40f8c81..1affc0e 100644 --- a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc +++ b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc @@ -26,16 +26,6 @@ namespace mindspore { namespace serving { -Status MindSporeModelWrap::InitEnv(serving::DeviceType device_type, uint32_t device_id, - const std::map &other_options) { - return SUCCESS; -} - -Status MindSporeModelWrap::FinalizeEnv() { - model_map_.clear(); - return SUCCESS; -} - mindspore::DataType TransInferDataType2ApiTypeId(DataType data_type) { const std::map type2id_map{ {serving::kMSI_Unknown, mindspore::DataType::kTypeUnknown}, @@ -81,11 +71,9 @@ DataType TransTypeId2InferDataType(mindspore::DataType type_id) { } Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, - const std::string &file_name, ModelType model_type, + const std::string &file_name, ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, - const std::map &other_options, - uint32_t *model_id) { - MSI_EXCEPTION_IF_NULL(model_id); + const std::map &other_options) { std::string device_type_str; if (device_type == kDeviceTypeAscendMS) { device_type_str = mindspore::kDeviceTypeAscend910; @@ -113,18 +101,18 @@ Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, ui << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; return Status(FAILED, status.ToString()); } - model_index_++; - 
*model_id = model_index_; ApiModelInfo api_model_info; api_model_info.model = model; api_model_info.device_type = device_type_str; api_model_info.device_id = device_id; + api_model_info.with_batch_dim = with_batch_dim; api_model_info.without_batch_dim_inputs = without_batch_dim_inputs; auto st = GetModelInfos(&api_model_info); if (st != SUCCESS) { return st; } - model_map_[*model_id] = api_model_info; + GetModelBatchSize(&api_model_info); + model_ = api_model_info; MSI_LOG_INFO << "Load model from file success, model file: " << file_name << ", device_type: '" << device_type_str << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; return SUCCESS; @@ -169,20 +157,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { MSI_EXCEPTION_IF_NULL(api_model_info); auto model = api_model_info->model; - bool first_dim_same = true; - auto find_batch_size = [&first_dim_same, api_model_info](const std::vector &shape) { - if (first_dim_same) { - if (shape.empty()) { - first_dim_same = false; - } else if (api_model_info->batch_size != 0) { - if (api_model_info->batch_size != shape[0]) { - first_dim_same = false; - } - } else { - api_model_info->batch_size = shape[0]; - } - } - }; auto get_tensor_info_from_tensor = [](const mindspore::MSTensor &ms_tensor) { serving::TensorInfo tensor_info; tensor_info.shape = ms_tensor.Shape(); @@ -204,10 +178,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { return INFER_STATUS_LOG_ERROR(FAILED) << "Unknown input mindspore data type " << static_cast(info.DataType()); } - const auto &list = api_model_info->without_batch_dim_inputs; - if (std::find(list.begin(), list.end(), i) == list.end()) { - find_batch_size(tensor_info.shape); - } api_model_info->input_tensor_infos.push_back(tensor_info); api_model_info->input_names.push_back(info.Name()); } @@ -220,27 +190,59 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { return 
INFER_STATUS_LOG_ERROR(FAILED) << "Unknown output mindspore data type " << static_cast(info.DataType()); } - find_batch_size(tensor_info.shape); api_model_info->output_tensor_infos.push_back(tensor_info); api_model_info->output_names.push_back(info.Name()); } } + return SUCCESS; +} + +void MindSporeModelWrap::GetModelBatchSize(ApiModelInfo *api_model_info) { + MSI_EXCEPTION_IF_NULL(api_model_info); + bool first_dim_same = true; + auto find_batch_size = [&first_dim_same, api_model_info](const std::vector &shape) { + if (!api_model_info->with_batch_dim) { + first_dim_same = false; + return; + } + if (!first_dim_same) { + return; + } + if (shape.empty()) { + first_dim_same = false; + return; + } + if (api_model_info->batch_size != 0) { + if (api_model_info->batch_size != shape[0]) { + first_dim_same = false; + } + } else { + api_model_info->batch_size = shape[0]; + } + }; + + auto list = api_model_info->without_batch_dim_inputs; + auto size = api_model_info->input_tensor_infos.size(); + for (size_t i = 0; i < size; i++) { + if (std::find(list.begin(), list.end(), i) == list.end()) { + auto &info = api_model_info->input_tensor_infos[i]; + find_batch_size(info.shape); + } + } + for (auto &info : api_model_info->output_tensor_infos) { + find_batch_size(info.shape); + } if (!first_dim_same) { api_model_info->batch_size = 0; } - return SUCCESS; } -Status MindSporeModelWrap::UnloadModel(uint32_t model_id) { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id; - } - model_map_.erase(it); +Status MindSporeModelWrap::UnloadModel() { + model_.model = nullptr; return SUCCESS; } -Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &request, serving::ReplyBase *reply) { +Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::ReplyBase *reply) { MSI_EXCEPTION_IF_NULL(reply); FuncMakeInBuffer func_in = [&request](size_t index, const std::string 
&name) { auto input_tensor = request[index]; @@ -260,11 +262,10 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &re tensor->set_data_type(data_type); tensor->set_shape(shape); }; - return ExecuteModelCommon(model_id, request.size(), func_in, func_out); + return ExecuteModelCommon(request.size(), func_in, func_out); } -Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) { +Status MindSporeModelWrap::ExecuteModel(const std::vector &request, std::vector *reply) { MSI_EXCEPTION_IF_NULL(reply); FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { auto &input_tensor = request[index]; @@ -282,16 +283,15 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vectorset_shape(shape); reply->push_back(tensor); }; - return ExecuteModelCommon(model_id, request.size(), func_in, func_out); + return ExecuteModelCommon(request.size(), func_in, func_out); } -Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, +Status MindSporeModelWrap::ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func) { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id; + if (model_.model == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded"; } - auto &model_info = it->second; + auto &model_info = model_; auto model = model_info.model; auto &input_names = model_info.input_names; auto &output_names = model_info.output_names; @@ -327,43 +327,25 @@ Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_ return SUCCESS; } -std::vector MindSporeModelWrap::GetInputInfos(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - 
return {}; - } - auto &model_info = it->second; - return model_info.input_tensor_infos; -} +std::vector MindSporeModelWrap::GetInputInfos() const { return model_.input_tensor_infos; } -std::vector MindSporeModelWrap::GetOutputInfos(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - return {}; - } - auto &model_info = it->second; - return model_info.output_tensor_infos; -} +std::vector MindSporeModelWrap::GetOutputInfos() const { return model_.output_tensor_infos; } -ssize_t MindSporeModelWrap::GetBatchSize(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - return {}; - } - auto &model_info = it->second; - return model_info.batch_size; -} +ssize_t MindSporeModelWrap::GetBatchSize() const { return model_.batch_size; } bool MindSporeModelWrap::CheckModelSupport(DeviceType device_type, ModelType model_type) const { std::string device_type_str; switch (device_type) { case kDeviceTypeAscendMS: + if (model_type != kMindIR) { + return false; + } device_type_str = mindspore::kDeviceTypeAscend910; break; case kDeviceTypeAscendCL: + if (model_type != kMindIR && model_type != kOM) { + return false; + } device_type_str = mindspore::kDeviceTypeAscend310; break; default: @@ -378,9 +360,5 @@ ApiBufferTensorWrap::ApiBufferTensorWrap(const mindspore::MSTensor &tensor) : te ApiBufferTensorWrap::~ApiBufferTensorWrap() = default; -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kOM, MindSporeModelWrap, 1); -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kMindIR, MindSporeModelWrap, 1); -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendMS, kMindIR, MindSporeModelWrap, 1); - } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h index 
14432ec..02f6f18 100644 --- a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h +++ b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h @@ -34,54 +34,46 @@ struct ApiModelInfo { std::vector input_tensor_infos; std::vector output_names; std::vector output_tensor_infos; - std::shared_ptr model; + std::shared_ptr model = nullptr; uint32_t batch_size = 0; std::string device_type; uint32_t device_id = 0; + bool with_batch_dim = false; std::vector without_batch_dim_inputs; }; -class MindSporeModelWrap : public InferSession { +class MindSporeModelWrap { public: MindSporeModelWrap() = default; ~MindSporeModelWrap() = default; - Status InitEnv(serving::DeviceType device_type, uint32_t device_id, - const std::map &other_options) override; - - Status FinalizeEnv() override; - Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name, - ModelType model_type, const std::vector &without_batch_dim_inputs, - const std::map &other_options, uint32_t *model_id) override; - - Status UnloadModel(uint32_t model_id) override; + ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, + const std::map &other_options); - // override this method to avoid request/reply data copy - Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) override; - Status ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) override; + Status UnloadModel(); + Status ExecuteModel(const RequestBase &request, ReplyBase *reply); + Status ExecuteModel(const std::vector &request, std::vector *reply); - std::vector GetInputInfos(uint32_t model_id) const override; + std::vector GetInputInfos() const; - std::vector GetOutputInfos(uint32_t model_id) const override; + std::vector GetOutputInfos() const; - ssize_t GetBatchSize(uint32_t model_id) const override; + ssize_t GetBatchSize() const; - bool CheckModelSupport(DeviceType device_type, ModelType 
model_type) const override; + bool CheckModelSupport(DeviceType device_type, ModelType model_type) const; private: - std::unordered_map model_map_; - uint32_t model_index_ = 0; + ApiModelInfo model_; using FuncMakeInBuffer = std::function; using FuncMakeOutTensor = std::function &shape)>; - Status ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, - const FuncMakeOutTensor &out_func); + Status ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func); Status GetModelInfos(ApiModelInfo *model_info); std::shared_ptr TransformModelContext(const std::map &other_options); + void GetModelBatchSize(ApiModelInfo *model_info); }; class ApiBufferTensorWrap : public TensorBase { diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc new file mode 100644 index 0000000..6f73444 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc @@ -0,0 +1,254 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "worker/local_servable/local_sevable.h" +#include +#include +#include +#include +#include +#include "common/tensor.h" +#include "common/file_system_operation.h" +#include "worker/context.h" + +namespace { +static const char *kVersionStrategyLatest = "latest"; +static const char *kVersionStrategySpecific = "specific"; +} // namespace + +namespace mindspore::serving { + +LocalModelServable::~LocalModelServable() { Clear(); } + +std::string LocalModelServable::GetServableName() const { return servable_name_; } + +uint64_t LocalModelServable::GetServableVersion() const { return version_number_; } + +Status LocalModelServable::Predict(const std::vector &input, std::vector *output) { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.ExecuteModel(input, output); +} + +std::vector LocalModelServable::GetInputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetInputInfos(); +} + +std::vector LocalModelServable::GetOutputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetOutputInfos(); +} + +uint64_t LocalModelServable::GetBatchSize() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetBatchSize(); +} + +Status LocalModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name, + uint64_t version_number) { + if (model_loaded_) { + MSI_LOG_EXCEPTION << "Model has loaded"; + } + base_spec_.servable_directory = servable_directory; + base_spec_.servable_name = servable_name; + base_spec_.version_number = version_number; + + std::string version_strategy; + if (version_number == 0) { + version_strategy = kVersionStrategyLatest; + } else { + version_strategy = kVersionStrategySpecific; + } + Status status; + ServableSignature signature; + if 
(!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; + } + status = InitDevice(signature.servable_meta.local_meta.model_format, {}); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Init env failed"; + return status; + } + + std::vector real_versions; + status = LoadServableConfig(base_spec_, version_strategy, &real_versions); + if (status != SUCCESS) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Start servable failed, there is no servable of the specified version number, specified version number: " + << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" + << base_spec_.servable_name + << "'. version number is a positive integer(started from 1) and 0 represents the maximum version number."; + } + auto real_version_number = real_versions[0]; + status = LoadModel(real_version_number); + if (status != SUCCESS) { + return status; + } + servable_name_ = base_spec_.servable_name; + version_number_ = real_version_number; + model_loaded_ = true; + MSI_LOG_INFO << status.StatusMessage(); + std::cout << status.StatusMessage() << std::endl; + return SUCCESS; +} + +void LocalModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { + MSI_EXCEPTION_IF_NULL(real_versions); + // define version_strategy:"specific","latest","multi" + if (version_strategy_ == kVersionStrategySpecific) { + real_versions->push_back(servable_spec.version_number); + return; + } + auto trans_to_integer = [](const std::string &str) -> uint32_t { + uint32_t parsed_value = 0; + for (auto c : str) { + if (c < '0' || c > '9') { + return 0; + } + parsed_value = parsed_value * 10 + c - '0'; + } + if (std::to_string(parsed_value) != str) { + return 0; + } + return parsed_value; + }; + uint64_t newest_version = 0; + std::string model_path = servable_spec.servable_directory + "/" + 
servable_spec.servable_name; + auto sub_dir = GetAllSubDirsNotFullPath(model_path); + static std::set ignore_dir; + for (const auto &dir : sub_dir) { + if (dir == "__pycache__") continue; + auto version_parse = trans_to_integer(dir); + if (version_parse == 0) { + if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { + MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory + << ", model_name " << servable_spec.servable_name; + } + continue; + } + real_versions->push_back(version_parse); + if (version_parse > newest_version) { + newest_version = version_parse; + } + } + if (version_strategy_ == kVersionStrategyLatest) { + real_versions->clear(); + if (newest_version != 0) { + real_versions->push_back(newest_version); + } + } +} + +Status LocalModelServable::LoadServableConfig(const LoadServableSpec &servable_spec, + const std::string &version_strategy, + std::vector *real_versions) { + MSI_EXCEPTION_IF_NULL(real_versions); + auto model_directory = servable_spec.servable_directory; + auto model_name = servable_spec.servable_name; + + if (!DirOrFileExist(model_directory + "/" + model_name)) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model not found, model_directory " << model_directory << ", model_name " << model_name; + } + std::string model_path = model_directory + "/" + model_name; + auto version_directory = [model_path](int64_t version_number) { + return model_path + "/" + std::to_string(version_number); + }; + version_strategy_ = version_strategy; + // version_strategy:"specific","latest","multi" + GetVersions(servable_spec, real_versions); + if (real_versions->size() == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name; + } + for (auto real_version_number : *real_versions) { + if (!DirOrFileExist(version_directory(real_version_number))) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Open 
failed for version " << real_version_number << ", model_directory " + << model_directory << ", model_name " << model_name; + } + } + return SUCCESS; +} + +Status LocalModelServable::InitDevice(ModelType model_type, const std::map &other_options) { + Status status; + auto context = ServableContext::Instance(); + DeviceType device_type = ServableContext::Instance()->GetDeviceType(); + auto get_support_device_type = [this, device_type, model_type]() { + std::vector support_device_list; + if (device_type == kDeviceTypeNotSpecified || device_type == kDeviceTypeAscend) { + auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; + for (auto item : ascend_list) { + if (session_.CheckModelSupport(item, model_type)) { + return item; + } + } + } else if (device_type == kDeviceTypeAscendCL || device_type == kDeviceTypeAscendMS) { + if (session_.CheckModelSupport(device_type, model_type)) { + return device_type; + } + } + return kDeviceTypeNotSpecified; + }; + auto support_device_type = get_support_device_type(); + if (support_device_type == kDeviceTypeNotSpecified) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Not support device type " << device_type << " and model type " << model_type + << ". 
Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; + } + context->SetDeviceType(support_device_type); + return SUCCESS; +} + +Status LocalModelServable::LoadModel(uint64_t version_number) { + ServableSignature signature; + if (!ServableStorage::Instance().GetServableDef(base_spec_.servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << base_spec_.servable_name << " has not been registered"; + } + const auto &servable_meta = signature.servable_meta; + const auto &common_meta = servable_meta.common_meta; + const auto &local_meta = servable_meta.local_meta; + std::string model_file_name = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" + + std::to_string(version_number) + "/" + local_meta.servable_file; + auto context = ServableContext::Instance(); + Status status = session_.LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, + local_meta.model_format, common_meta.with_batch_dim, + common_meta.without_batch_dim_inputs, local_meta.load_options); + if (status != SUCCESS) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '" + << base_spec_.servable_name << "', servable file: '" << local_meta.servable_file << "', version number " + << version_number << ", options " << local_meta.load_options; + } + return SUCCESS; +} + +void LocalModelServable::Clear() { + if (model_loaded_) { + session_.UnloadModel(); + } + model_loaded_ = false; +} + +} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h new file mode 100644 index 0000000..eb43356 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); 
+ * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H +#define MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H + +#include +#include +#include +#include + +#include "common/serving_common.h" +#include "common/instance.h" +#include "common/servable.h" +#include "worker/sevable_base.h" +#include "worker/inference/inference.h" +#include "worker/inference/mindspore_model_wrap.h" + +namespace mindspore::serving { + +class MS_API LocalModelServable : public ServableBase { + public: + LocalModelServable() = default; + ~LocalModelServable() override; + + Status Predict(const std::vector &input, std::vector *output) override; + + std::vector GetInputInfos() const override; + std::vector GetOutputInfos() const override; + uint64_t GetBatchSize() const override; + + Status StartServable(const std::string &servable_directory, const std::string &servable_name, + uint64_t version_number); + Status InitDevice(ModelType model_type, const std::map &other_options); + std::string GetServableName() const override; + uint64_t GetServableVersion() const override; + void Clear() override; + + private: + LoadServableSpec base_spec_; + std::string servable_name_; + uint64_t version_number_ = 0; + + MindSporeModelWrap session_; + std::string version_strategy_; + bool model_loaded_ = false; + + void GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions); + Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, + std::vector 
*real_version_number); + Status LoadModel(uint64_t version); +}; + +} // namespace mindspore::serving + +#endif // MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H diff --git a/mindspore_serving/ccsrc/worker/model.cc b/mindspore_serving/ccsrc/worker/model.cc deleted file mode 100644 index 814bd02..0000000 --- a/mindspore_serving/ccsrc/worker/model.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "worker/model.h" -#include -#include "mindspore_serving/ccsrc/common/tensor.h" - -namespace mindspore::serving { - -Status AscendModelServable::Predict(const std::vector &input, std::vector *output) { - return session_->ExecuteModel(model_id_, input, output); -} - -std::vector AscendModelServable::GetInputInfos() const { return session_->GetInputInfos(model_id_); } - -std::vector AscendModelServable::GetOutputInfos() const { return session_->GetOutputInfos(model_id_); } - -uint64_t AscendModelServable::GetBatchSize() const { return session_->GetBatchSize(model_id_); } - -} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/model.h b/mindspore_serving/ccsrc/worker/sevable_base.h similarity index 63% rename from mindspore_serving/ccsrc/worker/model.h rename to mindspore_serving/ccsrc/worker/sevable_base.h index dcb23e0..c3acd00 100644 --- a/mindspore_serving/ccsrc/worker/model.h +++ b/mindspore_serving/ccsrc/worker/sevable_base.h @@ -14,8 +14,8 @@ 
* limitations under the License. */ -#ifndef MINDSPORE_SERVING_WORKER_MODEL_H -#define MINDSPORE_SERVING_WORKER_MODEL_H +#ifndef MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H +#define MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H #include #include @@ -39,25 +39,11 @@ class ServableBase { virtual std::vector GetInputInfos() const = 0; virtual std::vector GetOutputInfos() const = 0; virtual uint64_t GetBatchSize() const = 0; -}; - -class AscendModelServable : public ServableBase { - public: - AscendModelServable(const std::shared_ptr &session, uint32_t model_id) - : session_(session), model_id_(model_id) {} - ~AscendModelServable() = default; - - Status Predict(const std::vector &input, std::vector *output) override; - - std::vector GetInputInfos() const override; - std::vector GetOutputInfos() const override; - uint64_t GetBatchSize() const override; - - private: - std::shared_ptr session_{nullptr}; - uint32_t model_id_ = 0; + virtual std::string GetServableName() const = 0; + virtual uint64_t GetServableVersion() const = 0; + virtual void Clear() = 0; }; } // namespace mindspore::serving -#endif // MINDSPORE_SERVING_WORKER_MODEL_H +#endif // MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H diff --git a/mindspore_serving/ccsrc/worker/work_executor.cc b/mindspore_serving/ccsrc/worker/work_executor.cc index 0453fb6..2ec760e 100644 --- a/mindspore_serving/ccsrc/worker/work_executor.cc +++ b/mindspore_serving/ccsrc/worker/work_executor.cc @@ -49,15 +49,15 @@ Status WorkExecutor::CheckSevableSignature() { if (servable_declare_.methods.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "There is no method registered for servable"; } - if (input_infos.size() != servable_declare_.servable_meta.inputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "The inputs count " << servable_declare_.servable_meta.inputs_count << " registered in method " - << "not equal to the count " << input_infos.size() << " defined in servable"; + const auto &common_meta = 
servable_declare_.servable_meta.common_meta; + if (input_infos.size() != common_meta.inputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count << " registered in method " + << "not equal to the count " << input_infos.size() << " defined in servable"; } const auto &output_infos = output_infos_; - if (output_infos.size() != servable_declare_.servable_meta.outputs_count) { + if (output_infos.size() != common_meta.outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) - << "The outputs count " << servable_declare_.servable_meta.outputs_count << " registered in method " + << "The outputs count " << common_meta.outputs_count << " registered in method " << "not equal to the count " << output_infos.size() << " defined in servable"; } MSI_LOG_INFO << "Model input infos: count " << input_infos.size(); @@ -68,7 +68,7 @@ Status WorkExecutor::CheckSevableSignature() { for (auto &item : output_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } - if (servable_declare_.servable_meta.with_batch_dim) { + if (common_meta.with_batch_dim) { if (model_batch_size_ == 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable batch size cannot be " << model_batch_size_; } @@ -104,7 +104,7 @@ Status WorkExecutor::Init(const ServableSignature &servable_declare, const std:: servable_ = servable; input_infos_ = servable_->GetInputInfos(); output_infos_ = servable_->GetOutputInfos(); - if (servable_declare_.servable_meta.with_batch_dim) { + if (servable_declare_.servable_meta.common_meta.with_batch_dim) { model_batch_size_ = servable_->GetBatchSize(); } else { model_batch_size_ = 1; @@ -389,7 +389,7 @@ Status WorkExecutor::PostPredict(const std::vector &inputs, const std: MSI_LOG_EXCEPTION << "Output result data size cannot be 0"; } auto shape = item->shape(); - if (servable_declare_.servable_meta.with_batch_dim) { + if (servable_declare_.servable_meta.common_meta.with_batch_dim) { if (shape.empty() || shape[0] != 
model_batch_size) { MSI_LOG_EXCEPTION << "Output shape " << shape << " not match batch size " << model_batch_size; } @@ -429,9 +429,9 @@ Status WorkExecutor::Predict(const std::vector &inputs, std::vector py_preprocess_task_queue, - std::shared_ptr py_postprocess_task_queue, - std::shared_ptr cpp_preprocess_task_queue, - std::shared_ptr cpp_postprocess_task_queue); + WorkExecutor(std::shared_ptr py_preprocess, std::shared_ptr py_postprocess, + std::shared_ptr cpp_preprocess, std::shared_ptr cpp_postprocess); ~WorkExecutor(); Status Init(const ServableSignature &servable_declare, const std::shared_ptr &servable); diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index aad5339..10e25e1 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -34,46 +34,16 @@ namespace py = pybind11; namespace mindspore { namespace serving { -static const char *kVersionStrategyLastest = "lastest"; -static const char *kVersionStrategySpecific = "specific"; -static std::unique_ptr grpc_async_worker_server_; - Worker &Worker::GetInstance() { static Worker instance; return instance; } -Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { - if (grpc_async_worker_server_ != nullptr) { - return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; - } - grpc_async_worker_server_ = std::make_unique(ip, grpc_port); - return grpc_async_worker_server_->Init(); -} - Status Worker::RegisterWorker() { - std::vector specs; - std::vector signatures; - for (auto &work : work_list_) { - specs.push_back(work.servable_spec); - signatures.push_back(work.servable_signature); - } std::vector worker_specs; - for (size_t i = 0; i < specs.size(); i++) { - auto &spec = specs[i]; - auto &servable_signature = signatures[i]; - WorkerSpec worker_spec; - worker_spec.servable_name = spec.servable_name; - worker_spec.version_number = spec.version_number; - for 
(auto &method : servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - worker_specs.push_back(worker_spec); + for (auto &work : work_list_) { + // cppcheck-suppress useStlAlgorithm + worker_specs.push_back(work.worker_spec); } auto status = notify_master_->Register(worker_specs); return status; @@ -84,34 +54,10 @@ Status Worker::StartVersionController() { return SUCCESS; } -Status Worker::AddWorker(const ServableWorkerContext &work) { - WorkerSpec worker_spec; - worker_spec.servable_name = work.servable_spec.servable_name; - worker_spec.version_number = work.servable_spec.version_number; - for (auto &method : work.servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - return notify_master_->AddWorker(worker_spec); -} +Status Worker::AddWorker(const ServableWorkerContext &work) { return notify_master_->AddWorker(work.worker_spec); } Status Worker::RemoveWorker(const ServableWorkerContext &work) { - WorkerSpec worker_spec; - worker_spec.servable_name = work.servable_spec.servable_name; - worker_spec.version_number = work.servable_spec.version_number; - for (auto &method : work.servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - return notify_master_->RemoveWorker(worker_spec); + return notify_master_->RemoveWorker(work.worker_spec); } Status Worker::Run(const proto::PredictRequest &request, proto::PredictReply *reply) { @@ -189,74 +135,8 @@ 
std::pair> Worker::RunAsync(const RequestSp return {SUCCESS, result}; } -Status Worker::InitEnv(ModelType model_type, const std::map &other_options) { - Status status; - if (session_) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Session has been inited"; - } - auto context = ServableContext::Instance(); - DeviceType device_type = kDeviceTypeNotSpecified; - session_ = InferSessionStorage::Instance().Get(context->GetDeviceType(), model_type, &device_type); - if (session_ == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Cannot find session registered for device type " << context->GetDeviceType() << " and model type " - << model_type << ". Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; - } - if (device_type != kDeviceTypeNotSpecified) { - context->SetDeviceType(device_type); - } - status = session_->InitEnv(context->GetDeviceType(), context->GetDeviceId(), other_options); - if (status != SUCCESS) { - session_ = nullptr; - return INFER_STATUS_LOG_ERROR(FAILED) - << "Init env failed, device type " << context->GetDeviceType() << ", device id " << context->GetDeviceId(); - } - return SUCCESS; -} - -Status Worker::FinalizeEnv() { - if (session_ != nullptr) { - return session_->FinalizeEnv(); - } - return SUCCESS; -} -Status Worker::LoadModel(LoadServableSpec *servable_spec, uint64_t version_number, ServableWorkerContext *work) { - MSI_EXCEPTION_IF_NULL(servable_spec); - MSI_EXCEPTION_IF_NULL(work); - servable_spec->version_number = version_number; - ServableSignature signature; - if (!ServableStorage::Instance().GetServableDef(servable_spec->servable_name, &signature)) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_spec->servable_name << " has not been registerd"; - } - const auto &servable_meta = signature.servable_meta; - std::string model_file_name = servable_spec->servable_directory + "/" + servable_spec->servable_name + "/" + - std::to_string(version_number) + "/" + servable_meta.servable_file; - uint32_t 
model_id; - auto context = ServableContext::Instance(); - Status status = session_->LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, - servable_meta.model_format, servable_meta.without_batch_dim_inputs, - servable_meta.load_options, &model_id); - if (status != SUCCESS) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Load model failed, servable directory: '" << servable_spec->servable_directory << "', servable name: '" - << servable_spec->servable_name << "', servable file: '" << servable_meta.servable_file - << "', version number " << version_number << ", options " << servable_meta.load_options; - } - auto service = std::make_shared(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), - GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); - status = service->Init(signature, std::make_shared(session_, model_id)); - if (status != SUCCESS) { - return status; - } - work->servable_spec = *servable_spec; - work->servable_signature = signature; - work->worker_service = service; - work->model_id = model_id; - work->model_file_name = model_file_name; - return SUCCESS; -} - void Worker::Update() { + /* if (version_strategy_ == kVersionStrategySpecific) { return; } @@ -291,10 +171,19 @@ void Worker::Update() { MSI_LOG_INFO << "UnLoad Model version " << iter->servable_spec.version_number << " success"; work_list_.erase(iter); } + */ } -Status Worker::StartServable(const std::string &servable_directory, const std::string &servable_name, - uint32_t version_number, std::shared_ptr notify_master) { +Status Worker::StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port) { + if (worker_grpc_server_ != nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; + } + worker_grpc_server_ = grpc_server; + return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port); +} + +Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { 
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (servable_started_) { MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently."; @@ -307,58 +196,42 @@ Status Worker::StartServable(const std::string &servable_directory, const std::s cpp_postprocess_.Start(2); notify_master_ = std::move(notify_master); - base_spec_.servable_directory = servable_directory; - base_spec_.servable_name = servable_name; - base_spec_.version_number = version_number; - - std::string version_strategy; - if (version_number == 0) { - version_strategy = kVersionStrategyLastest; - } else { - version_strategy = kVersionStrategySpecific; - } - Status status; + auto servable_name = servable->GetServableName(); ServableSignature signature; if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; - } - if (session_ == nullptr) { - status = InitEnv(signature.servable_meta.model_format, {}); - if (status != SUCCESS) { - MSI_LOG_ERROR << "Init env failed"; - return status; - } + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_name << " has not been registered"; } - - std::vector real_versions; - status = LoadServableConfig(base_spec_, version_strategy, &real_versions); + auto service = std::make_shared(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), + GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); + auto status = service->Init(signature, servable); if (status != SUCCESS) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Start servable failed, there is no servable of the specified version number, specified version number: " - << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" - << base_spec_.servable_name - << "'. 
version number is a positive integer(started from 1) and 0 represents the maximum version number."; + return status; } - for (auto real_version_number : real_versions) { - ServableWorkerContext work; - status = LoadModel(&base_spec_, real_version_number, &work); - if (status != SUCCESS) { - return status; + ServableWorkerContext work; + WorkerSpec worker_spec; + worker_spec.servable_name = servable_name; + worker_spec.version_number = servable->GetServableVersion(); + for (auto &method : signature.methods) { + WorkerMethodInfo worker_method_info; + worker_method_info.name = method.method_name; + for (auto &name : method.inputs) { + worker_method_info.input_names.push_back(name); } - work_list_.push_back(work); + worker_spec.methods.push_back(worker_method_info); } + work.worker_spec = worker_spec; + work.servable_signature = signature; + work.worker_service = service; + work.servable = servable; + + work_list_.push_back(work); + status = RegisterWorker(); if (status != SUCCESS) { MSI_LOG_ERROR << "Register worker failed"; return status; } servable_started_ = true; - status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory - << "', servable name: '" << servable_name - << "', specified version number: " << version_number - << ", started version numbers: " << real_versions; - MSI_LOG_INFO << status.StatusMessage(); - std::cout << status.StatusMessage() << std::endl; return SUCCESS; } @@ -368,119 +241,39 @@ void Worker::StopServable(bool notify_master) { } void Worker::Clear() { + std::unique_lock lock(worker_shared_lock_); + ServableStorage::Instance().Clear(); + worker_grpc_server_ = nullptr; if (clear_flag_.test_and_set()) { return; } - std::unique_lock lock(worker_shared_lock_); MSI_LOG_INFO << "Start clear worker session"; version_controller_.StopPollModelPeriodic(); if (exit_notify_master_ && servable_started_) { notify_master_->Unregister(); } - if (session_ != nullptr) { - for (auto &it : work_list_) { - 
session_->UnloadModel(it.model_id); - } + for (auto &worker_item : work_list_) { + worker_item.servable->Clear(); } work_list_.clear(); - FinalizeEnv(); - session_ = nullptr; py_task_queue_group_.Stop(); cpp_preprocess_.Stop(); cpp_postprocess_.Stop(); - ServableStorage::Instance().Clear(); - grpc_async_worker_server_ = nullptr; servable_started_ = false; MSI_LOG_INFO << "End clear worker session"; } -bool Worker::HasCleared() { return !servable_started_; } +bool Worker::IsRunning() { return servable_started_; } Worker::~Worker() { Clear(); } -void Worker::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { - MSI_EXCEPTION_IF_NULL(real_versions); - // define version_strategy:"specific","lastest","multi" - if (version_strategy_ == kVersionStrategySpecific) { - real_versions->push_back(servable_spec.version_number); - return; - } - auto trans_to_integer = [](const std::string &str) -> uint32_t { - uint32_t parsed_value = 0; - for (auto c : str) { - if (c < '0' || c > '9') { - return 0; - } - parsed_value = parsed_value * 10 + c - '0'; - } - if (std::to_string(parsed_value) != str) { - return 0; - } - return parsed_value; - }; - uint64_t newest_version = 0; - std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name; - auto sub_dir = GetAllSubDirsNotFullPath(model_path); - static std::set ignore_dir; - for (const auto &dir : sub_dir) { - if (dir == "__pycache__") continue; - auto version_parse = trans_to_integer(dir); - if (version_parse == 0) { - if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { - MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory - << ", model_name " << servable_spec.servable_name; - } - continue; - } - real_versions->push_back(version_parse); - if (version_parse > newest_version) { - newest_version = version_parse; - } - } - if (version_strategy_ == kVersionStrategyLastest) { - real_versions->clear(); - if 
(newest_version != 0) { - real_versions->push_back(newest_version); - } - } -} -Status Worker::LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, - std::vector *real_versions) { - MSI_EXCEPTION_IF_NULL(real_versions); - auto model_directory = servable_spec.servable_directory; - auto model_name = servable_spec.servable_name; - - if (!DirOrFileExist(model_directory + "/" + model_name)) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model not found, model_directory " << model_directory << ", model_name " << model_name; - } - std::string model_path = model_directory + "/" + model_name; - auto version_directory = [model_path](int64_t version_number) { - return model_path + "/" + std::to_string(version_number); - }; - version_strategy_ = version_strategy; - // version_strategy:"specific","lastest","multi" - GetVersions(servable_spec, real_versions); - if (real_versions->size() == 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name; - } - for (auto real_version_number : *real_versions) { - if (!DirOrFileExist(version_directory(real_version_number))) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory " - << model_directory << ", model_name " << model_name; - } - } - return SUCCESS; -} - ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) { ServableWorkerContext context; if (request_spec.version_number != 0) { auto item = find_if(work_list_.begin(), work_list_.end(), [&](const ServableWorkerContext &v) { - return v.servable_spec.servable_name == request_spec.servable_name && - v.servable_spec.version_number == request_spec.version_number; + return v.worker_spec.servable_name == request_spec.servable_name && + v.worker_spec.version_number == request_spec.version_number; }); if (item != work_list_.end()) { context = *item; @@ -488,10 
+281,10 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) } else { uint64_t max_version = 0; for (auto &item : work_list_) { - if (item.servable_spec.servable_name == request_spec.servable_name && - item.servable_spec.version_number > max_version) { + if (item.worker_spec.servable_name == request_spec.servable_name && + item.worker_spec.version_number > max_version) { context = item; - max_version = item.servable_spec.version_number; + max_version = item.worker_spec.version_number; } } } @@ -500,11 +293,11 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) Worker::Worker() {} -ssize_t Worker::GetBatchSize() const { - ssize_t batch_size_ret = -1; - for (auto service : work_list_) { - auto batch_size = session_->GetBatchSize(service.model_id); - if (batch_size != -1) { +size_t Worker::GetBatchSize() const { + size_t batch_size_ret = 1; + for (const auto &service : work_list_) { + auto batch_size = service.servable->GetBatchSize(); + if (batch_size != 0) { batch_size_ret = batch_size; break; } @@ -532,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) { const int kWaitMaxHundredMs = 100; int i; for (i = 0; i < kWaitMaxHundredMs; i++) { // - if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) { + if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) { instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped"); return SYSTEM_ERROR; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 007a5cd..ef66043 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -32,6 +32,8 @@ #include "worker/task_queue.h" #include "worker/version_control/version_controller.h" #include "common/grpc_async_server.h" +#include "worker/sevable_base.h" +#include "worker/grpc/worker_server.h" namespace mindspore { namespace serving { @@ -53,11 
+55,10 @@ class AsyncResult { }; struct ServableWorkerContext { - LoadServableSpec servable_spec; + WorkerSpec worker_spec; ServableSignature servable_signature; std::shared_ptr worker_service = nullptr; - uint32_t model_id = 0; - std::string model_file_name; + std::shared_ptr servable = nullptr; }; class MS_API Worker { @@ -72,17 +73,14 @@ class MS_API Worker { Status Run(const RequestSpec &request_spec, const std::vector &inputs, std::vector *outputs); std::pair> RunAsync(const RequestSpec &request_spec, const std::vector &inputs); + Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); - Status InitEnv(ModelType model_type, const std::map &other_options); - Status FinalizeEnv(); + Status StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port); - Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, - std::shared_ptr notify_master); void StopServable(bool notify_master = true); - bool HasCleared(); + bool IsRunning(); Status RegisterWorker(); - Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); - Status LoadModel(LoadServableSpec *servable_spec, uint64_t version, ServableWorkerContext *work); void Update(); Status StartVersionController(); Status AddWorker(const ServableWorkerContext &work); @@ -93,31 +91,24 @@ class MS_API Worker { std::shared_ptr GetPyTaskQueuePostprocess() { return py_task_queue_group_.GetPostprocessTaskQueue(); } std::shared_ptr GetCppTaskQueuePreprocess() { return cpp_preprocess_.GetTaskQueue(); } std::shared_ptr GetCppTaskQueuePostprocess() { return cpp_postprocess_.GetTaskQueue(); } - ssize_t GetBatchSize() const; + size_t GetBatchSize() const; private: - static std::shared_ptr global_worker_; - std::vector work_list_; - std::shared_ptr session_ = nullptr; - std::string version_strategy_; PyTaskQueueGroup py_task_queue_group_; PreprocessThreadPool cpp_preprocess_; PostprocessThreadPool 
cpp_postprocess_; VersionController version_controller_; - LoadServableSpec base_spec_; std::atomic_bool exit_notify_master_ = true; std::atomic_bool servable_started_ = false; std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; std::shared_ptr notify_master_ = nullptr; + std::shared_ptr worker_grpc_server_ = nullptr; std::shared_mutex worker_shared_lock_; ServableWorkerContext GetServableWorker(const RequestSpec &request_spec); - Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, - std::vector *real_version_number); - void GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions); }; } // namespace serving diff --git a/mindspore_serving/master/_master.py b/mindspore_serving/master/_master.py index 78abb52..0d61459 100644 --- a/mindspore_serving/master/_master.py +++ b/mindspore_serving/master/_master.py @@ -18,6 +18,7 @@ import threading from functools import wraps from mindspore_serving.worker import check_type from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Master_ _wait_and_clear_thread = None @@ -59,6 +60,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/proto/ms_agent.proto b/mindspore_serving/proto/ms_agent.proto new file mode 100644 index 0000000..1642993 --- /dev/null +++ b/mindspore_serving/proto/ms_agent.proto @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ms_manager.proto +syntax = "proto3"; + +package mindspore.serving.proto; +import "mindspore_serving/proto/ms_service.proto"; + +message DistributedPredictRequest { + repeated Tensor inputs = 1; +} + +message DistributedPredictReply { + repeated Tensor outputs = 1; + ErrorMsg error_msg = 2; +} + +message DistributedExitRequest { + string address = 1; +} + +message DistributedExitReply { + ErrorMsg error_msg = 1; +} + +service MSAgent { + rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {} + rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {} +} diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto new file mode 100644 index 0000000..27fa6c4 --- /dev/null +++ b/mindspore_serving/proto/ms_distributed.proto @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// ms_manager.proto +syntax = "proto3"; + +package mindspore.serving.proto; +import "mindspore_serving/proto/ms_service.proto"; + +message AgentSpec { + int64 rank_id = 1; + int64 batch_size = 2; + repeated Tensor inputs = 3; + repeated Tensor outputs = 4; +} + +message AgentRegisterRequest { + repeated AgentSpec agent_spec = 1; + string address = 2; +} + +message AgentRegisterReply { + ErrorMsg error_msg = 1; +} + +message AgentExitRequest { + repeated AgentSpec agent_spec = 1; + string address = 2; +} + +message AgentExitReply { + ErrorMsg error_msg = 1; +} + +message AgentFailedRequest { +} + +message AgentFailedReply { + ErrorMsg error_msg = 1; +} diff --git a/mindspore_serving/proto/ms_service.proto b/mindspore_serving/proto/ms_service.proto index 908c1dd..ddcccd4 100644 --- a/mindspore_serving/proto/ms_service.proto +++ b/mindspore_serving/proto/ms_service.proto @@ -80,6 +80,8 @@ message Tensor { // for string type and images, the dtype is MS_BYTES. repeated bytes bytes_val = 4; + + int64 size = 5; } message ServableSpec { diff --git a/mindspore_serving/proto/ms_worker.proto b/mindspore_serving/proto/ms_worker.proto index c9ed051..7b2dbe0 100644 --- a/mindspore_serving/proto/ms_worker.proto +++ b/mindspore_serving/proto/ms_worker.proto @@ -20,8 +20,14 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; import "mindspore_serving/proto/ms_master.proto"; +import "mindspore_serving/proto/ms_distributed.proto"; service MSWorker { + // for master rpc Predict(PredictRequest) returns (PredictReply) {} rpc Exit(ExitRequest) returns (ExitReply) {} + // for worker agent + rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} + rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} + rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} } diff --git a/mindspore_serving/worker/_worker.py b/mindspore_serving/worker/_worker.py index 33ba9fb..b11de95 100644 --- 
a/mindspore_serving/worker/_worker.py +++ b/mindspore_serving/worker/_worker.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Inferface for start up servable""" +"""Interface for start up servable""" import threading from functools import wraps from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ from .register.preprocess import preprocess_storage from .register.postprocess import postprocess_storage @@ -77,6 +78,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py new file mode 100644 index 0000000..8bf27d1 --- /dev/null +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -0,0 +1,250 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Serving, distributed worker agent startup""" + +import os +import time +from multiprocessing import Process, Pipe + +from mindspore_serving._mindspore_serving import ExitSignalHandle_ +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ + +from mindspore_serving import log as logger +from mindspore_serving.worker import check_type +from mindspore_serving.worker.distributed import worker_agent + + +def _get_local_ip(rank_list, port): + """Get the local ip from the rank table config""" + import socket + ip_list = [] + for item in rank_list: + ip_list.append(item.ip) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + for ip in ip_list: + try: + s.bind((ip, port)) + logger.info(f"Get local machine ip success, ip {ip}") + return ip + # pylint: disable=bare-except + except: + pass + raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") + + +def _update_model_files_path(model_files, group_config_files): + """Check and return model files or group config files""" + script_dir = os.path.dirname(os.path.realpath(__file__)) + logger.info(f"input model files: {model_files}") + logger.info(f"input group config files: {group_config_files}") + model_files_temp = [] + for item in model_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access model file '{file_name}'") + model_files_temp.append(file_name) + + group_files_temp = [] + for item in group_config_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access group config file '{file_name}'") + group_files_temp.append(file_name) + + logger.info(f"absolute model files: {model_files_temp}") + logger.info(f"absolute group config files: {group_files_temp}") + return model_files_temp, group_files_temp + + +def 
_make_json_table_file(distributed_config): + """Make rank table json file""" + rank_size = len(distributed_config.rank_list) + runtime_dir = os.path.abspath(".") + time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) + rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") + with open(rank_table_file_name, "w") as fp: + fp.write(distributed_config.rank_table_content) + return rank_table_file_name + + +signal_success = "Success" +signal_exit = "Exit" +signal_heartbeat = "HeartBeat" + + +def _recv_parent(index, recv_pipe): + """Receive message from Start up process. + Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit. + Return True on receiving signal_success.""" + try: + while True: + heartbeat_count = 0 + while not recv_pipe.poll(0.1): + if ExitSignalHandle_.has_stopped(): + logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker") + return False + heartbeat_count += 1 + if heartbeat_count >= 30: # 3s + logger.warning(f"Child {index}: Exit on failure of receiving parent message") + return False + parent_signal = recv_pipe.recv() + if parent_signal != signal_heartbeat: + break + if parent_signal == signal_success: + logger.info(f"Child {index}: Receive success") + return True + if parent_signal == signal_exit: + logger.warning(f"Child {index}: Exit on receiving exit message") + else: + logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}") + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Child {index}: Exit on exception: {e}") + return False + + +def _agent_process(send_pipe, recv_pipe, index, start_config): + """Agent process""" + try: + # listening success or failed message from parent process + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message + worker_agent.start_worker_agent(start_config=start_config) + send_pipe.send((index, signal_success)) + 
success_msg = _recv_parent(index, recv_pipe) + if not success_msg: + worker_agent.stop() + send_pipe.close() + recv_pipe.close() + # pylint: disable=broad-except + except Exception as e: + logger.error(f"Child {index}: Catch exception and notify exit of others") + send_pipe.send((index, e)) + worker_agent.stop() + raise + + +def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list): + """Listening child process""" + def send_pipe_msg(send_pipe, msg): + try: + send_pipe.send(msg) + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Send pipe message exception happen: {e}") + + count = len(send_pipe_list) + for _ in range(count): + while True: + if p_recv_pipe.poll(0.1): + break + for send_pipe, process in zip(send_pipe_list, subprocess_list): + if process.is_alive(): + continue + logger.warning("Fail to start agents because of death of one agent") + for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list): + if process_x.is_alive(): + send_pipe_msg(send_pipe_x, signal_exit) + return False + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_heartbeat) + + _, msg = p_recv_pipe.recv() + if isinstance(msg, Exception): + logger.warning("Fail to start agents because of exception raise by one agent") + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_exit) + return False + + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_success) + logger.info("Success to start agents") + return True + + +def _startup_all_agents(common_meta, worker_ip, worker_port, + agent_ip, agent_start_port, device_id_list, rank_id_list, + model_files, group_config_files, rank_table_file): + """Start up all agents in one machine""" + servable_name = common_meta.servable_name + index = 0 + send_pipe_list = [] + subprocess_list = [] + c_send_pipe, p_recv_pipe = Pipe() + for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files, + group_config_files): + 
p_send_pipe, c_recv_pipe = Pipe() + send_pipe_list.append(p_send_pipe) + + agent_port = agent_start_port + index + + start_config = AgentStartUpConfig_() + start_config.rank_id = rank_id + start_config.device_id = device_id + start_config.model_file_name = model_file + start_config.group_file_name = group_file + start_config.rank_table_json_file_name = rank_table_file + start_config.agent_ip = agent_ip + start_config.agent_port = agent_port + start_config.worker_ip = worker_ip + start_config.worker_port = worker_port + start_config.common_meta = common_meta + + process = Process(target=_agent_process, + args=(c_send_pipe, c_recv_pipe, index, start_config), + name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}") + process.start() + subprocess_list.append(process) + index += 1 + ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list) + if not ret: + WorkerAgent_.notify_failed(worker_ip, worker_port) + + +def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000): + """Start up all needed worker agents on one machine""" + check_type.check_str("worker_ip", worker_ip) + check_type.check_ip_port("worker_port", worker_port) + check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) + model_files = check_type.check_and_as_int_tuple_list("model_files", model_files) + group_config_files = check_type.check_and_as_int_tuple_list("group_config_files", group_config_files) + distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port) + + # get machine ip + rank_list = distributed_config.rank_list + local_ip = _get_local_ip(rank_list, agent_start_port) + # get all device_id and rank_id + local_device_id_list = [] + local_rank_id_list = [] + for rank_id, item in enumerate(rank_list): + if item.ip == local_ip: + local_device_id_list.append(item.device_id) + local_rank_id_list.append(rank_id) + + # handle model files and group config files + if 
len(local_device_id_list) != len(model_files): + raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to model files size " + f"{len(model_files)}, model files: {model_files}") + + if len(local_device_id_list) != len(group_config_files): + raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to group config " + f"files size {len(group_config_files)}, group config files: {group_config_files}") + + model_files, group_config_files = _update_model_files_path(model_files, group_config_files) + + # make json table file and export env + rank_table_file = _make_json_table_file(distributed_config) + _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port, + local_device_id_list, local_rank_id_list, + model_files, group_config_files, rank_table_file) diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py new file mode 100644 index 0000000..4235ee6 --- /dev/null +++ b/mindspore_serving/worker/distributed/distributed_worker.py @@ -0,0 +1,131 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Serving, distributed worker startup""" +from mindspore_serving._mindspore_serving import Worker_ + +from mindspore_serving.worker import check_type +from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear +from mindspore_serving.worker._worker import stop_on_except, _load_servable_config + + +@stop_on_except +def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, + worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100, + wait_agents_time_in_seconds=300): + r""" + Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master + through gRPC (master_ip, master_port). + + Serving has two running modes. One is running in a single process, providing the Serving service of a single model. + The other includes a master and multiple workers. This interface is for the second scenario. + + The master is responsible for providing the Serving access interface for clients, + while the worker is responsible for providing the inference service of the specific model. The communications + between the master and workers through gPRC are defined as (master_ip, master_port) and (worker_ip, worker_port). + + Args: + servable_directory (str): The directory where the servable is located in. There expects to has a directory + named `servable_name`. For more detail: + `How to config Servable `_ . + + servable_name (str): The servable name. + version_number (int): Servable version number to be loaded. The version number should be a positive integer, + starting from 1, and 0 means to load the latest version. Default: 0. + rank_table_json_file (str): The ranke table json file name. + master_ip (str): The master ip the worker linked to. + master_port (int): The master port the worker linked to. + worker_ip (str): The worker ip the master and agents linked to. 
+ worker_port (int): The worker port the master and agents linked to. + wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents. + + Examples: + >>> import os + >>> from mindspore_serving import worker + >>> + >>> servable_dir = os.path.abspath(".") + >>> worker.start_servable(servable_dir, "lenet", device_id=0, \ + ... master_ip="127.0.0.1", master_port=6500, \ + ... host_ip="127.0.0.1", host_port=6600) + """ + check_type.check_str('servable_directory', servable_directory) + check_type.check_str('servable_name', servable_name) + check_type.check_int('version_number', version_number, 0) + if version_number == 0: + version_number = 1 + check_type.check_str('rank_table_json_file', rank_table_json_file) + + check_type.check_str('master_ip', master_ip) + check_type.check_ip_port('master_port', master_port) + + check_type.check_str('worker_ip', worker_ip) + check_type.check_ip_port('worker_port', worker_port) + + _load_servable_config(servable_directory, servable_name) + Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, + master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds) + _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() + + +@stop_on_except +def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, + worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300): + r""" + Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in + the process of the master. + + Serving has two running modes. One is running in a single process, providing the Serving service of a single model. + The other includes a master and multiple workers. This interface is for the first scenario. + + Args: + servable_directory (str): The directory where the servable is located in. There expects to has a directory named + `servable_name`. 
For more detail: + `How to config Servable `_ . + + servable_name (str): The servable name. + version_number (int): Servable version number to be loaded. The version number should be a positive integer, + starting from 1, and 0 means to load the latest version. Default: 0. + rank_table_json_file (str): The ranke table json file name. + worker_ip (str): The worker ip the agents linked to. + worker_port (int): The worker port the agents linked to. + wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents. + + Examples: + >>> import os + >>> from mindspore_serving import worker + >>> from mindspore_serving import master + >>> + >>> servable_dir = os.path.abspath(".") + >>> worker.start_servable_in_master(servable_dir, "lenet", device_id=0) + >>> + >>> master.start_grpc_server("0.0.0.0", 5500) + >>> master.start_restful_server("0.0.0.0", 1500) + """ + check_type.check_str('servable_directory', servable_directory) + check_type.check_str('servable_name', servable_name) + check_type.check_int('version_number', version_number, 0) + if version_number == 0: + version_number = 1 + + check_type.check_str('rank_table_json_file', rank_table_json_file) + + check_type.check_str('worker_ip', worker_ip) + check_type.check_ip_port('worker_port', worker_port) + + _load_servable_config(servable_directory, servable_name) + Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, + version_number, worker_ip, worker_port, wait_agents_time_in_seconds) + _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() diff --git a/mindspore_serving/worker/distributed/register.py b/mindspore_serving/worker/distributed/register.py new file mode 100644 index 0000000..bac0f35 --- /dev/null +++ b/mindspore_serving/worker/distributed/register.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Serving, distributed worker register""" + +from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ +from mindspore_serving.worker import check_type +from mindspore_serving.worker.common import get_servable_dir +from mindspore_serving import log as logger + + +def declare_distributed_servable(rank_size, stage_size, with_batch_dim, without_batch_dim_inputs): + """declare distributed servable in servable_config.py""" + check_type.check_bool('with_batch_dim', with_batch_dim) + + meta = ServableMeta_() + meta.common_meta.servable_name = get_servable_dir() + meta.common_meta.with_batch_dim = with_batch_dim + if without_batch_dim_inputs: + without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', + without_batch_dim_inputs, 0) + meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs + + # init distributed servable meta info + check_type.check_int("rank_size", rank_size, 1) + check_type.check_int("stage_size", stage_size, 1) + meta.distributed_meta.rank_size = rank_size + meta.distributed_meta.stage_size = stage_size + ServableStorage_.declare_distributed_servable(meta) + logger.info(f"Declare distributed servable, servable_name: {meta.common_meta.servable_name} " + f", rank_size: {rank_size} , stage_size: {stage_size}, with_batch_dim: {with_batch_dim} " + f", without_batch_dim_inputs: {without_batch_dim_inputs}") diff --git 
a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py new file mode 100644 index 0000000..ad32a53 --- /dev/null +++ b/mindspore_serving/worker/distributed/worker_agent.py @@ -0,0 +1,66 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Serving, distributed worker agent""" + +import os +import threading +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ +from mindspore_serving import log as logger + + +def start_worker_agent(start_config): + """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents + """ + if not isinstance(start_config, AgentStartUpConfig_): + raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") + + os.environ["RANK_ID"] = str(start_config.rank_id) + os.environ["DEVICE_ID"] = str(start_config.device_id) + os.environ["MS_ENABLE_HCCL"] = "1" + os.environ["PARA_GROUP_FILE"] = start_config.group_file_name + os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name + + for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE", + "LD_LIBRARY_PATH", "PYTHONPATH"): + logger.info(f"Env {item}: {os.getenv(item, '')}") + WorkerAgent_.start_agent(start_config) + + start_wait_and_clear() + + +_wait_and_clear_thread = None + + +def 
start_wait_and_clear(): + """Waiting for Ctrl+C, and clear up environment""" + + def thread_func(): + logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------") + print("Serving worker: wait for Ctrl+C to exit ------------------------------------") + WorkerAgent_.wait_and_clear() + logger.info("Serving worker: exited ------------------------------------") + print("Serving worker: exited ------------------------------------") + + global _wait_and_clear_thread + if not _wait_and_clear_thread: + _wait_and_clear_thread = threading.Thread(target=thread_func) + _wait_and_clear_thread.start() + + +def stop(): + r""" + Stop the running of agent. + """ + WorkerAgent_.stop_and_clear() diff --git a/mindspore_serving/worker/register/method.py b/mindspore_serving/worker/register/method.py index 224796c..ccde72c 100644 --- a/mindspore_serving/worker/register/method.py +++ b/mindspore_serving/worker/register/method.py @@ -35,28 +35,6 @@ method_tag_predict = PredictPhaseTag_.kPredictPhaseTag_Predict method_tag_postprocess = PredictPhaseTag_.kPredictPhaseTag_Postprocess -class _ServableStorage: - """Declare servable info""" - - def __init__(self): - pass - - @staticmethod - def declare_servable(servable_meta): - """Declare servable info excluding method, input and output count""" - ServableStorage_.declare_servable(servable_meta) - - @staticmethod - def declare_servable_input_output(servable_name, inputs_count, outputs_count): - """Declare input and output count of servable""" - ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) - - @staticmethod - def register_method(method_signature): - """Declare method of servable""" - ServableStorage_.register_method(method_signature) - - class _TensorDef: """Data flow item, for definitions of data flow in a method""" @@ -251,7 +229,7 @@ def call_servable(*args): servable_name = get_servable_dir() inputs_count, outputs_count = 
method_def_ast_meta_[_call_servable_name] - _ServableStorage.declare_servable_input_output(servable_name, inputs_count, outputs_count) + ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) if inputs_count != len(args): raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', given servable input " f"size {len(args)} not match '{servable_name}' ast parse size {inputs_count}") @@ -467,7 +445,7 @@ def register_method(output_names): f", servable_name {method_def_context_.servable_name}, inputs: {input_names}, outputs: " f"{output_names}") - _ServableStorage.register_method(method_def_context_) + ServableStorage_.register_method(method_def_context_) return func return register diff --git a/mindspore_serving/worker/register/servable.py b/mindspore_serving/worker/register/servable.py index 97b4829..97eb23d 100644 --- a/mindspore_serving/worker/register/servable.py +++ b/mindspore_serving/worker/register/servable.py @@ -14,11 +14,10 @@ # ============================================================================ """Servable declaration interface""" -from mindspore_serving._mindspore_serving import ServableMeta_ +from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ from mindspore_serving.worker import check_type from mindspore_serving.worker.common import get_servable_dir from mindspore_serving import log as logger -from .method import _ServableStorage def declare_servable(servable_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None): @@ -37,19 +36,25 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N RuntimeError: The type or value of the parameters is invalid. 
""" - check_type.check_str('servable_file', servable_file) - check_type.check_str('model_format', model_format) check_type.check_bool('with_batch_dim', with_batch_dim) + meta = ServableMeta_() + meta.common_meta.servable_name = get_servable_dir() + meta.common_meta.with_batch_dim = with_batch_dim + if without_batch_dim_inputs: + without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', + without_batch_dim_inputs, 0) + meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs + + # init local servable meta info + check_type.check_str('servable_file', servable_file) + check_type.check_str('model_format', model_format) model_format = model_format.lower() if model_format not in ("om", "mindir"): raise RuntimeError("model format can only be OM or MindIR") - meta = ServableMeta_() - meta.servable_name = get_servable_dir() - meta.servable_file = servable_file - meta.set_model_format(model_format) - meta.with_batch_dim = with_batch_dim + meta.local_meta.servable_file = servable_file + meta.local_meta.set_model_format(model_format) if isinstance(options, dict): for k, w in options.items(): check_type.check_str("options key", k) @@ -61,14 +66,10 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N raise RuntimeError(f"Parameter 'options' should be None, dict of or AclOptions, but " f"gotten {type(options)}") if options: - meta.options = options - if without_batch_dim_inputs: - without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', - without_batch_dim_inputs, 0) - meta.without_batch_dim_inputs = without_batch_dim_inputs + meta.local_meta.options = options - _ServableStorage.declare_servable(meta) - logger.info(f"Declare servable, servable_name: {meta.servable_name} " + ServableStorage_.declare_servable(meta) + logger.info(f"Declare servable, servable_name: {meta.common_meta.servable_name} " f", servable_file: {servable_file} , model_format: {model_format}, 
with_batch_dim: {with_batch_dim} " f", options: {options}, without_batch_dim_inputs: {without_batch_dim_inputs}") diff --git a/tests/ut/cpp/common/test_servable_common.h b/tests/ut/cpp/common/test_servable_common.h index 91df994..5f565ac 100644 --- a/tests/ut/cpp/common/test_servable_common.h +++ b/tests/ut/cpp/common/test_servable_common.h @@ -27,6 +27,7 @@ #include "worker/worker.h" #include "worker/notfiy_master/local_notify.h" #include "worker/context.h" +#include "worker/local_servable/local_sevable.h" #include "master/grpc/grpc_process.h" #include "mindspore_serving/proto/ms_service.pb.h" @@ -102,16 +103,21 @@ class TestMasterWorker : public UT::Common { auto notify_master = std::make_shared(); ServableContext::Instance()->SetDeviceId(0); ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable(servable_dir, servable_name, version_number, notify_master); + auto servable = std::make_shared(); + auto status = servable->StartServable(servable_dir, servable_name, version_number); + if (status != SUCCESS) { + return status; + } + status = Worker::GetInstance().StartServable(servable, notify_master); return status; } static void DeclareServable(const std::string &servable_name, const std::string &servable_file, const std::string &model_type, bool with_batch_dim = false) { ServableMeta servable_meta; - servable_meta.servable_name = servable_name; - servable_meta.servable_file = servable_file; - servable_meta.SetModelFormat(model_type); - servable_meta.with_batch_dim = with_batch_dim; + servable_meta.common_meta.servable_name = servable_name; + servable_meta.common_meta.with_batch_dim = with_batch_dim; + servable_meta.local_meta.servable_file = servable_file; + servable_meta.local_meta.SetModelFormat(model_type); // declare_servable ServableStorage::Instance().DeclareServable(servable_meta); } diff --git a/tests/ut/cpp/tests/test_start_worker.cc b/tests/ut/cpp/tests/test_start_worker.cc index 63e0838..8ab955b 
100644 --- a/tests/ut/cpp/tests/test_start_worker.cc +++ b/tests/ut/cpp/tests/test_start_worker.cc @@ -30,10 +30,7 @@ TEST_F(TestStartWorker, test_worker_start_success) { DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } @@ -43,10 +40,7 @@ TEST_F(TestStartWorker, test_worker_start_error_model_file_name) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + auto status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Load model failed, servable directory: "); } @@ -57,12 +51,8 @@ TEST_F(TestStartWorker, test_worker_start_error_version_number) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); int error_version_number = 2; - Status status = - Worker::GetInstance().StartServable("test_servable_dir", "test_servable", error_version_number, notify_master); + auto status = StartServable("test_servable_dir", "test_servable", error_version_number); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start 
servable failed, there is no servable of" @@ -78,11 +68,8 @@ TEST_F(TestStartWorker, test_worker_start_multi_version_number) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); int version_number = 0; - Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", version_number, notify_master); + Status status = StartServable(servable_dir, "test_servable", version_number); EXPECT_TRUE(status.IsSuccess()); } @@ -96,10 +83,7 @@ TEST_F(TestStartWorker, test_worker_start_version_number_no_valid) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", 0, notify_master); + Status status = StartServable(servable_dir, "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start servable failed, there is no servable of" @@ -112,11 +96,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_dir) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); std::string error_servable_dir = "test_servable_dir_error"; - Status status = Worker::GetInstance().StartServable(error_servable_dir, "test_servable", 0, notify_master); + Status status = StartServable(error_servable_dir, "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start servable failed, there is no servable of" @@ -129,11 +110,8 @@ TEST_F(TestStartWorker, 
test_worker_start_error_servable_name) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); std::string error_servable_name = "test_servable_error"; - Status status = Worker::GetInstance().StartServable("test_servable_dir", error_servable_name, 0, notify_master); + Status status = StartServable("test_servable_dir", error_servable_name, 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "'test_servable_error' has not been registered"); } @@ -144,24 +122,18 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_format) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); - ExpectContainMsg(status.StatusMessage(), "Cannot find session registered for device type Ascend and model type OM"); + ExpectContainMsg(status.StatusMessage(), "Not support device type Ascend and model type OM. 
"); } TEST_F(TestStartWorker, test_worker_start_no_registered_method) { - Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); + Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); // no registered method // RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 2); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "There is no method registered for servable"); } @@ -181,10 +153,7 @@ TEST_F(TestStartWorker, test_worker_start_multi_method) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); RegisterMethod("test_servable", "add_common2", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } @@ -194,10 +163,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_input_count_not_match) size_t servable_input_count = 1; RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, servable_input_count, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = 
StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The inputs count 1 registered in method not equal to " @@ -210,10 +176,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_output_count_not_match size_t servable_output_count = 2; RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, servable_output_count); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The outputs count 2 registered in method not equal to " @@ -241,10 +204,7 @@ TEST_F(TestStartWorker, test_worker_start_preprocess_not_found) { ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), " preprocess preprocess_fake_fun not defined") } @@ -269,10 +229,7 @@ TEST_F(TestStartWorker, test_worker_start_postprocess_not_found) { ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = 
StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), " postprocess postprocess_fake_fun not defined") } @@ -300,10 +257,7 @@ TEST_F(TestStartWorker, test_worker_start_with_preproces_and_postprocess_success ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } diff --git a/tests/ut/runtest.sh b/tests/ut/runtest.sh index 9a3dfa7..da92b72 100755 --- a/tests/ut/runtest.sh +++ b/tests/ut/runtest.sh @@ -16,21 +16,24 @@ set -e -CURRPATH=$(cd "$(dirname $0)" || exit; pwd) +CURRPATH=$( + cd "$(dirname $0)" || exit + pwd +) if [ $# -gt 0 ]; then - if [ $1 == "python" ]; then - echo "run python ut" - bash ${CURRPATH}/python/runtest.sh $2 - elif [ $1 == "cpp" ]; then - echo "run cpp ut" - bash ${CURRPATH}/cpp/runtest.sh - fi -else - echo "run all ut" - # 1.run python testcases + if [ $1 == "python" ]; then + echo "run python ut" bash ${CURRPATH}/python/runtest.sh $2 - - # 2.run c++ ut testcases + elif [ $1 == "cpp" ]; then + echo "run cpp ut" bash ${CURRPATH}/cpp/runtest.sh + fi +else + echo "run all ut" + # 1.run python testcases + bash ${CURRPATH}/python/runtest.sh $2 + + # 2.run c++ ut testcases + bash ${CURRPATH}/cpp/runtest.sh fi