From: @xu-yfei Reviewed-by: @zhoufeng54,@linqingke Signed-off-by: @zhoufeng54tags/v1.2.0
| @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { | |||
| exit_future.wait(); | |||
| } | |||
| // waiting ctrl+c or stop message to exit, | |||
| // if no server is running or server has exited, there is no need to wait | |||
| void ExitSignalHandle::AgentWait() { | |||
| if (!is_running_) { | |||
| MSI_LOG_INFO << "Exit Handle has not started or has exited"; | |||
| return; | |||
| } | |||
| auto exit_future = agent_exit_requested_.get_future(); | |||
| exit_future.wait(); | |||
| } | |||
| void ExitSignalHandle::Start() { | |||
| if (is_running_) { | |||
| return; | |||
| @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { | |||
| is_running_ = true; | |||
| master_exit_requested_ = std::promise<void>(); | |||
| worker_exit_requested_ = std::promise<void>(); | |||
| agent_exit_requested_ = std::promise<void>(); | |||
| has_exited_.clear(); | |||
| InitSignalHandle(); | |||
| } | |||
| @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { | |||
| if (!has_exited_.test_and_set()) { | |||
| master_exit_requested_.set_value(); | |||
| worker_exit_requested_.set_value(); | |||
| agent_exit_requested_.set_value(); | |||
| is_running_ = false; | |||
| } | |||
| } | |||
| @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { | |||
| void InitSignalHandle(); | |||
| void MasterWait(); | |||
| void WorkerWait(); | |||
| void AgentWait(); | |||
| void Start(); | |||
| void Stop(); | |||
| bool HasStopped(); | |||
| @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { | |||
| private: | |||
| std::promise<void> master_exit_requested_; | |||
| std::promise<void> worker_exit_requested_; | |||
| std::promise<void> agent_exit_requested_; | |||
| std::atomic_flag has_exited_ = true; | |||
| std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; | |||
| std::atomic_bool is_running_ = false; | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/grpc_client.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| std::unique_ptr<MSPredictClient> client_; | |||
| std::unique_ptr<MSDistributedClient> distributed_client_; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| #define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <memory> | |||
| #include <functional> | |||
| #include <thread> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_service.pb.h" | |||
| #include "proto/ms_service.grpc.pb.h" | |||
| #include "proto/ms_master.pb.h" | |||
| #include "proto/ms_master.grpc.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "proto/ms_agent.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| using PredictOnFinish = std::function<void()>; | |||
| using DispatchCallback = std::function<void(Status status)>; | |||
| template <typename Request, typename Reply, typename MSStub> | |||
| class MSServiceClient { | |||
| public: | |||
| MSServiceClient() = default; | |||
| ~MSServiceClient() { | |||
| if (in_running_) { | |||
| cq_.Shutdown(); | |||
| if (client_thread_.joinable()) { | |||
| try { | |||
| client_thread_.join(); | |||
| } catch (const std::system_error &) { | |||
| } catch (...) { | |||
| } | |||
| } | |||
| } | |||
| in_running_ = false; | |||
| } | |||
| void Start() { | |||
| client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); | |||
| in_running_ = true; | |||
| } | |||
| void AsyncCompleteRpc() { | |||
| void *got_tag; | |||
| bool ok = false; | |||
| while (cq_.Next(&got_tag, &ok)) { | |||
| AsyncClientCall *call = static_cast<AsyncClientCall *>(got_tag); | |||
| if (call->status.ok()) { | |||
| call->callback(SUCCESS); | |||
| } else { | |||
| MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); | |||
| call->callback(Status(FAILED, call->status.error_message())); | |||
| } | |||
| delete call; | |||
| } | |||
| } | |||
| void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) { | |||
| AsyncClientCall *call = new AsyncClientCall; | |||
| call->reply = reply; | |||
| call->callback = std::move(callback); | |||
| call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); | |||
| call->response_reader->StartCall(); | |||
| call->response_reader->Finish(call->reply, &call->status, call); | |||
| MSI_LOG(INFO) << "Finish send Predict"; | |||
| } | |||
| private: | |||
| struct AsyncClientCall { | |||
| grpc::ClientContext context; | |||
| grpc::Status status; | |||
| Reply *reply; | |||
| DispatchCallback callback; | |||
| std::shared_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader; | |||
| }; | |||
| grpc::CompletionQueue cq_; | |||
| std::thread client_thread_; | |||
| bool in_running_ = false; | |||
| }; | |||
| using MSPredictClient = MSServiceClient<proto::PredictRequest, proto::PredictReply, proto::MSWorker::Stub>; | |||
| using MSDistributedClient = | |||
| MSServiceClient<proto::DistributedPredictRequest, proto::DistributedPredictReply, proto::MSAgent::Stub>; | |||
| extern std::unique_ptr<MSPredictClient> client_; | |||
| extern std::unique_ptr<MSDistributedClient> distributed_client_; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| @@ -341,6 +341,56 @@ Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::Predict | |||
| return SUCCESS; | |||
| } | |||
| void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) { | |||
| worker_specs->rank_id = specs.rank_id(); | |||
| worker_specs->batch_size = specs.batch_size(); | |||
| for (auto &in : specs.inputs()) { | |||
| TensorInfo info; | |||
| info.data_type = ProtoTensor::TransDataType2Inference(in.dtype()); | |||
| info.size = in.size(); | |||
| for (auto &dim : in.shape().dims()) { | |||
| info.shape.push_back(dim); | |||
| } | |||
| worker_specs->input_infos.push_back(info); | |||
| } | |||
| for (auto &out : specs.outputs()) { | |||
| TensorInfo info; | |||
| info.data_type = ProtoTensor::TransDataType2Inference(out.dtype()); | |||
| for (auto &dim : out.shape().dims()) { | |||
| info.shape.push_back(dim); | |||
| } | |||
| worker_specs->output_infos.push_back(info); | |||
| } | |||
| } | |||
| void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs, | |||
| proto::AgentRegisterRequest *request) { | |||
| for (size_t i = 0; i < worker_specs.size(); i++) { | |||
| auto &spec = worker_specs[i]; | |||
| auto worker_spec = request->add_agent_spec(); | |||
| worker_spec->set_rank_id(spec.rank_id); | |||
| worker_spec->set_batch_size(spec.batch_size); | |||
| for (auto &method : spec.input_infos) { | |||
| auto proto_method = worker_spec->add_inputs(); | |||
| proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); | |||
| proto_method->set_size(method.size); | |||
| auto proto_shape = proto_method->mutable_shape(); | |||
| for (auto &dim : method.shape) { | |||
| proto_shape->add_dims(dim); | |||
| } | |||
| } | |||
| for (auto &method : spec.output_infos) { | |||
| auto proto_method = worker_spec->add_outputs(); | |||
| proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); | |||
| proto_method->set_size(method.size); | |||
| auto proto_shape = proto_method->mutable_shape(); | |||
| for (auto &dim : method.shape) { | |||
| proto_shape->add_dims(dim); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { | |||
| Status status; | |||
| ProtoTensor tensor_input(const_cast<proto::Tensor *>(&tensor)); | |||
| @@ -24,6 +24,7 @@ | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_service.pb.h" | |||
| #include "proto/ms_master.pb.h" | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "common/instance.h" | |||
| #include "common/servable.h" | |||
| @@ -68,6 +69,9 @@ class MS_API GrpcTensorHelper { | |||
| std::vector<InstanceData> *results); | |||
| static Status CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector<Instance> &inputs, | |||
| proto::PredictReply *reply); | |||
| static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs); | |||
| static void CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs, | |||
| proto::AgentRegisterRequest *request); | |||
| private: | |||
| static Status CreateInstanceFromRequestInstances(const proto::PredictRequest &request, | |||
| @@ -25,11 +25,23 @@ namespace mindspore::serving { | |||
| std::string ServableMeta::Repr() const { | |||
| std::ostringstream stream; | |||
| stream << "path(" << servable_name << ") file(" << servable_file + ")"; | |||
| switch (servable_type) { | |||
| case kServableTypeUnknown: | |||
| stream << "undeclared servable, servable name: '" << common_meta.servable_name << "'"; | |||
| break; | |||
| case kServableTypeLocal: | |||
| stream << "local servable, servable name: '" << common_meta.servable_name << "', file: '" | |||
| << local_meta.servable_file + "'"; | |||
| break; | |||
| case kServableTypeDistributed: | |||
| stream << "distributed servable, servable name: '" << common_meta.servable_name | |||
| << "', rank size: " << distributed_meta.rank_size << ", stage size " << distributed_meta.stage_size; | |||
| break; | |||
| } | |||
| return stream.str(); | |||
| } | |||
| void ServableMeta::SetModelFormat(const std::string &format) { | |||
| void LocalServableMeta::SetModelFormat(const std::string &format) { | |||
| if (format == "om") { | |||
| model_format = kOM; | |||
| } else if (format == "mindir") { | |||
| @@ -63,142 +75,181 @@ std::string RequestSpec::Repr() const { | |||
| return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version; | |||
| } | |||
| Status ServableSignature::Check() const { | |||
| std::set<std::string> method_set; | |||
| Status ServableSignature::CheckPreprocessInput(const MethodSignature &method, size_t *preprocess_outputs_count) const { | |||
| std::string model_str = servable_meta.Repr(); | |||
| const auto &preprocess_name = method.preprocess_name; | |||
| if (!preprocess_name.empty()) { | |||
| auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); | |||
| if (preprocess == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name | |||
| << " preprocess " << preprocess_name << " not defined"; | |||
| } | |||
| *preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); | |||
| for (auto &method : methods) { | |||
| if (method_set.count(method.method_name) > 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " " << method.method_name << " has been defined repeatly"; | |||
| for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { | |||
| auto &input = method.preprocess_inputs[i]; | |||
| if (input.first != kPredictPhaseTag_Input) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i | |||
| << "th input cannot not come from '" << input.first << "'"; | |||
| } | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i | |||
| << "th input uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| } | |||
| method_set.emplace(method.method_name); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| size_t preprocess_outputs_count = 0; | |||
| size_t postprocess_outputs_count = 0; | |||
| Status ServableSignature::CheckPredictInput(const MethodSignature &method, size_t preprocess_outputs_count) const { | |||
| std::string model_str = servable_meta.Repr(); | |||
| const auto &preprocess_name = method.preprocess_name; | |||
| if (!preprocess_name.empty()) { | |||
| auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); | |||
| if (preprocess == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name | |||
| << " preprocess " << preprocess_name << " not defined"; | |||
| for (size_t i = 0; i < method.servable_inputs.size(); i++) { | |||
| auto &input = method.servable_inputs[i]; | |||
| if (input.first == kPredictPhaseTag_Input) { | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the servable " << i | |||
| << "th input uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); | |||
| for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { | |||
| auto &input = method.preprocess_inputs[i]; | |||
| if (input.first != kPredictPhaseTag_Input) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i | |||
| << "th input cannot not come from '" << input.first << "'"; | |||
| } | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i | |||
| << "th input uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Preproces) { | |||
| if (input.second >= preprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the servable " << i | |||
| << "th input uses preprocess " << input.second | |||
| << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; | |||
| } | |||
| } else { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i | |||
| << "th input cannot not come from '" << input.first << "'"; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ServableSignature::CheckPostprocessInput(const MethodSignature &method, size_t preprocess_outputs_count, | |||
| size_t *postprocess_outputs_count) const { | |||
| std::string model_str = servable_meta.Repr(); | |||
| const auto &common_meta = servable_meta.common_meta; | |||
| const auto &postprocess_name = method.postprocess_name; | |||
| if (!method.postprocess_name.empty()) { | |||
| auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); | |||
| if (postprocess == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name | |||
| << " postprocess " << postprocess_name << " not defined"; | |||
| } | |||
| *postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); | |||
| for (size_t i = 0; i < method.servable_inputs.size(); i++) { | |||
| auto &input = method.servable_inputs[i]; | |||
| for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { | |||
| auto &input = method.postprocess_inputs[i]; | |||
| if (input.first == kPredictPhaseTag_Input) { | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the servable " << i | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Preproces) { | |||
| if (input.second >= preprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the servable " << i | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses preprocess " << input.second | |||
| << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Predict) { | |||
| if (input.second >= common_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses servable " << input.second | |||
| << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; | |||
| } | |||
| } else { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i | |||
| << "th input cannot not come from '" << input.first << "'"; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| const auto &postprocess_name = method.postprocess_name; | |||
| if (!method.postprocess_name.empty()) { | |||
| auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); | |||
| if (postprocess == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name | |||
| << " postprocess " << postprocess_name << " not defined"; | |||
| } | |||
| postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); | |||
| Status ServableSignature::CheckReturn(const MethodSignature &method, size_t preprocess_outputs_count, | |||
| size_t postprocess_outputs_count) const { | |||
| std::string model_str = servable_meta.Repr(); | |||
| const auto &common_meta = servable_meta.common_meta; | |||
| for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { | |||
| auto &input = method.postprocess_inputs[i]; | |||
| if (input.first == kPredictPhaseTag_Input) { | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses method " << input.second | |||
| << "th input, that is greater than the method inputs size " << method.inputs.size(); | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Preproces) { | |||
| if (input.second >= preprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses preprocess " << input.second | |||
| << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Predict) { | |||
| if (input.second >= servable_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i | |||
| << "th input uses servable " << input.second | |||
| << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; | |||
| } | |||
| } else { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i | |||
| << "th input cannot not come from '" << input.first << "'"; | |||
| } | |||
| for (size_t i = 0; i < method.returns.size(); i++) { | |||
| auto &input = method.returns[i]; | |||
| if (input.first == kPredictPhaseTag_Input) { | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < method.returns.size(); i++) { | |||
| auto &input = method.returns[i]; | |||
| if (input.first == kPredictPhaseTag_Input) { | |||
| if (input.second >= method.inputs.size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses method " << input.second << "th input, that is greater than the method inputs size " | |||
| << method.inputs.size(); | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Preproces) { | |||
| if (input.second >= preprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses preprocess " << input.second | |||
| << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Predict) { | |||
| if (input.second >= servable_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses servable " << input.second | |||
| << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Postprocess) { | |||
| if (input.second >= postprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses postprocess " << input.second | |||
| << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; | |||
| } | |||
| } else { | |||
| } else if (input.first == kPredictPhaseTag_Preproces) { | |||
| if (input.second >= preprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses preprocess " << input.second | |||
| << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Predict) { | |||
| if (input.second >= common_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses servable " << input.second | |||
| << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; | |||
| } | |||
| } else if (input.first == kPredictPhaseTag_Postprocess) { | |||
| if (input.second >= postprocess_outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of method " << i | |||
| << "th output cannot not come from '" << input.first << "'"; | |||
| << "Model " << model_str << " method " << method.method_name << ", the method " << i | |||
| << "th output uses postprocess " << input.second | |||
| << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; | |||
| } | |||
| } else { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << model_str << " method " << method.method_name << ", the data of method " << i | |||
| << "th output cannot not come from '" << input.first << "'"; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ServableSignature::Check() const { | |||
| std::set<std::string> method_set; | |||
| Status status; | |||
| for (auto &method : methods) { | |||
| if (method_set.count(method.method_name) > 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model " << servable_meta.Repr() << " " << method.method_name << " has been defined repeatedly"; | |||
| } | |||
| method_set.emplace(method.method_name); | |||
| size_t preprocess_outputs_count = 0; | |||
| size_t postprocess_outputs_count = 0; | |||
| status = CheckPreprocessInput(method, &preprocess_outputs_count); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| status = CheckPredictInput(method, preprocess_outputs_count); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| status = CheckPostprocessInput(method, preprocess_outputs_count, &postprocess_outputs_count); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| status = CheckReturn(method, preprocess_outputs_count, postprocess_outputs_count); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -216,7 +267,7 @@ bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodS | |||
| } | |||
| void ServableStorage::Register(const ServableSignature &def) { | |||
| auto model_name = def.servable_meta.servable_name; | |||
| auto model_name = def.servable_meta.common_meta.servable_name; | |||
| if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) { | |||
| MSI_LOG_WARNING << "Servable " << model_name << " has already been defined"; | |||
| } | |||
| @@ -258,16 +309,60 @@ Status ServableStorage::RegisterMethod(const MethodSignature &method) { | |||
| return SUCCESS; | |||
| } | |||
| void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) { | |||
| MSI_LOG_INFO << "Declare servable " << servable.servable_name; | |||
| auto it = servable_signatures_map_.find(servable.servable_name); | |||
| Status ServableStorage::DeclareServable(ServableMeta servable) { | |||
| auto &common_meta = servable.common_meta; | |||
| MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; | |||
| servable.servable_type = kServableTypeLocal; | |||
| if (servable.local_meta.servable_file.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty"; | |||
| } | |||
| if (servable.local_meta.model_format == ModelType::kUnknownType) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Declare servable " << common_meta.servable_name << " failed, model_format is not inited"; | |||
| } | |||
| auto it = servable_signatures_map_.find(common_meta.servable_name); | |||
| if (it == servable_signatures_map_.end()) { | |||
| ServableSignature signature; | |||
| signature.servable_meta = servable; | |||
| servable_signatures_map_[servable.servable_name] = signature; | |||
| return; | |||
| servable_signatures_map_[common_meta.servable_name] = signature; | |||
| return SUCCESS; | |||
| } | |||
| it->second.servable_meta = servable; | |||
| auto &org_servable_meta = it->second.servable_meta; | |||
| if (org_servable_meta.servable_type != kServableTypeUnknown) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); | |||
| } | |||
| org_servable_meta = servable; | |||
| return SUCCESS; | |||
| } | |||
| Status ServableStorage::DeclareDistributedServable(ServableMeta servable) { | |||
| auto &common_meta = servable.common_meta; | |||
| MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; | |||
| servable.servable_type = kServableTypeDistributed; | |||
| if (servable.distributed_meta.rank_size == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Declare distributed servable " << common_meta.servable_name << " failed, rank_size cannot be 0"; | |||
| } | |||
| if (servable.distributed_meta.stage_size == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Declare distributed servable " << common_meta.servable_name << " failed, stage_size cannot be 0"; | |||
| } | |||
| auto it = servable_signatures_map_.find(common_meta.servable_name); | |||
| if (it == servable_signatures_map_.end()) { | |||
| ServableSignature signature; | |||
| signature.servable_meta = servable; | |||
| servable_signatures_map_[common_meta.servable_name] = signature; | |||
| return SUCCESS; | |||
| } | |||
| auto &org_servable_meta = it->second.servable_meta; | |||
| if (org_servable_meta.servable_type != kServableTypeUnknown) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); | |||
| } | |||
| org_servable_meta = servable; | |||
| return SUCCESS; | |||
| } | |||
| Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, | |||
| @@ -277,18 +372,19 @@ Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name; | |||
| } | |||
| auto &servable_meta = it->second.servable_meta; | |||
| if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) { | |||
| auto &common_meta = servable_meta.common_meta; | |||
| if (common_meta.inputs_count != 0 && common_meta.inputs_count != inputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count " | |||
| << servable_meta.inputs_count << ",servable name " << servable_name; | |||
| << common_meta.inputs_count << ",servable name " << servable_name; | |||
| } | |||
| if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) { | |||
| if (common_meta.outputs_count != 0 && common_meta.outputs_count != outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count " | |||
| << servable_meta.outputs_count << ",servable name " << servable_name; | |||
| << common_meta.outputs_count << ",servable name " << servable_name; | |||
| } | |||
| servable_meta.inputs_count = inputs_count; | |||
| servable_meta.outputs_count = outputs_count; | |||
| common_meta.inputs_count = inputs_count; | |||
| common_meta.outputs_count = outputs_count; | |||
| return SUCCESS; | |||
| } | |||
| @@ -298,8 +394,8 @@ std::vector<size_t> ServableStorage::GetInputOutputInfo(const std::string &serva | |||
| if (it == servable_signatures_map_.end()) { | |||
| return result; | |||
| } | |||
| result.push_back(it->second.servable_meta.inputs_count); | |||
| result.push_back(it->second.servable_meta.outputs_count); | |||
| result.push_back(it->second.servable_meta.common_meta.inputs_count); | |||
| result.push_back(it->second.servable_meta.common_meta.outputs_count); | |||
| return result; | |||
| } | |||
| @@ -81,19 +81,39 @@ struct RequestSpec { | |||
| std::string Repr() const; | |||
| }; | |||
| struct MS_API ServableMeta { | |||
| enum ServableType { | |||
| kServableTypeUnknown = 0, | |||
| kServableTypeLocal = 1, | |||
| kServableTypeDistributed = 2, | |||
| }; | |||
| struct CommonServableMeta { | |||
| std::string servable_name; | |||
| std::string servable_file; // file name | |||
| ModelType model_format; // OM, MindIR | |||
| bool with_batch_dim = true; // whether there is batch dim in model's inputs/outputs | |||
| std::vector<int> without_batch_dim_inputs; | |||
| size_t inputs_count = 0; | |||
| size_t outputs_count = 0; | |||
| }; | |||
| std::map<std::string, std::string> load_options; // Acl options | |||
| std::vector<int> without_batch_dim_inputs; | |||
| struct MS_API LocalServableMeta { | |||
| std::string servable_file; // file name | |||
| ModelType model_format = ModelType::kUnknownType; // OM, MindIR | |||
| std::map<std::string, std::string> load_options; // Acl options | |||
| void SetModelFormat(const std::string &format); | |||
| }; | |||
| struct DistributedServableMeta { | |||
| size_t rank_size = 0; | |||
| size_t stage_size = 0; | |||
| }; | |||
| struct MS_API ServableMeta { | |||
| ServableType servable_type = kServableTypeUnknown; | |||
| CommonServableMeta common_meta; | |||
| LocalServableMeta local_meta; | |||
| DistributedServableMeta distributed_meta; | |||
| std::string Repr() const; | |||
| void SetModelFormat(const std::string &format); | |||
| }; | |||
| struct ServableSignature { | |||
| @@ -102,6 +122,12 @@ struct ServableSignature { | |||
| Status Check() const; | |||
| bool GetMethodDeclare(const std::string &method_name, MethodSignature *method); | |||
| private: | |||
| Status CheckPreprocessInput(const MethodSignature &method, size_t *pre) const; | |||
| Status CheckPredictInput(const MethodSignature &method, size_t pre) const; | |||
| Status CheckPostprocessInput(const MethodSignature &method, size_t pre, size_t *post) const; | |||
| Status CheckReturn(const MethodSignature &method, size_t pre, size_t post) const; | |||
| }; | |||
| class MS_API ServableStorage { | |||
| @@ -111,7 +137,8 @@ class MS_API ServableStorage { | |||
| bool GetServableDef(const std::string &model_name, ServableSignature *def) const; | |||
| void DeclareServable(const ServableMeta &servable); | |||
| Status DeclareServable(ServableMeta servable); | |||
| Status DeclareDistributedServable(ServableMeta servable); | |||
| Status RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); | |||
| std::vector<size_t> GetInputOutputInfo(const std::string &servable_name) const; | |||
| @@ -144,6 +171,14 @@ static inline LogStream &operator<<(LogStream &stream, PredictPhaseTag data_type | |||
| return stream; | |||
| } | |||
| struct WorkerAgentSpec { | |||
| std::string agent_address; | |||
| uint32_t rank_id = 0; | |||
| std::vector<TensorInfo> input_infos; | |||
| std::vector<TensorInfo> output_infos; | |||
| uint32_t batch_size = 0; | |||
| }; | |||
| } // namespace mindspore::serving | |||
| #endif // MINDSPORE_SERVING_SERVABLE_H | |||
| @@ -27,7 +27,7 @@ | |||
| #include "common/instance.h" | |||
| #include "common/servable.h" | |||
| #include "master/notify_worker/base_notify.h" | |||
| #include "master/grpc/grpc_client.h" | |||
| #include "common/grpc_client.h" | |||
| namespace mindspore::serving { | |||
| @@ -1,73 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "master/grpc/grpc_client.h" | |||
| #include <string> | |||
| #include <utility> | |||
| #include "master/grpc/grpc_server.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| std::unique_ptr<MSServiceClient> client_; | |||
| MSServiceClient::~MSServiceClient() { | |||
| if (in_running_) { | |||
| cq_.Shutdown(); | |||
| if (client_thread_.joinable()) { | |||
| try { | |||
| client_thread_.join(); | |||
| } catch (const std::system_error &) { | |||
| } catch (...) { | |||
| } | |||
| } | |||
| } | |||
| in_running_ = false; | |||
| } | |||
| void MSServiceClient::PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| std::shared_ptr<proto::MSWorker::Stub> stub, DispatchCallback callback) { | |||
| AsyncClientCall *call = new AsyncClientCall; | |||
| call->reply = reply; | |||
| call->callback = std::move(callback); | |||
| call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); | |||
| call->response_reader->StartCall(); | |||
| call->response_reader->Finish(call->reply, &call->status, call); | |||
| MSI_LOG(INFO) << "Finish send Predict"; | |||
| } | |||
| void MSServiceClient::AsyncCompleteRpc() { | |||
| void *got_tag; | |||
| bool ok = false; | |||
| while (cq_.Next(&got_tag, &ok)) { | |||
| AsyncClientCall *call = static_cast<AsyncClientCall *>(got_tag); | |||
| if (call->status.ok()) { | |||
| call->callback(SUCCESS); | |||
| } else { | |||
| MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); | |||
| call->callback(Status(FAILED, call->status.error_message())); | |||
| } | |||
| delete call; | |||
| } | |||
| } | |||
| void MSServiceClient::Start() { | |||
| client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); | |||
| in_running_ = true; | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -1,68 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| #define MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <memory> | |||
| #include <functional> | |||
| #include <thread> | |||
| #include "common/serving_common.h" | |||
| #include "master/notify_worker/base_notify.h" | |||
| #include "proto/ms_service.pb.h" | |||
| #include "proto/ms_service.grpc.pb.h" | |||
| #include "proto/ms_master.pb.h" | |||
| #include "proto/ms_master.grpc.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MSServiceClient; | |||
| extern std::unique_ptr<MSServiceClient> client_; | |||
| using PredictOnFinish = std::function<void()>; | |||
| class MSServiceClient { | |||
| public: | |||
| MSServiceClient() = default; | |||
| ~MSServiceClient(); | |||
| void AsyncCompleteRpc(); | |||
| void Start(); | |||
| void PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| std::shared_ptr<proto::MSWorker::Stub> stub, DispatchCallback callback); | |||
| private: | |||
| struct AsyncClientCall { | |||
| grpc::ClientContext context; | |||
| grpc::Status status; | |||
| proto::PredictReply *reply; | |||
| DispatchCallback callback; | |||
| std::shared_ptr<grpc::ClientAsyncResponseReader<proto::PredictReply>> response_reader; | |||
| }; | |||
| grpc::CompletionQueue cq_; | |||
| std::thread client_thread_; | |||
| bool in_running_ = false; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_MASTER_GRPC_CLIENT_H | |||
| @@ -22,12 +22,11 @@ | |||
| #include "common/serving_common.h" | |||
| #include "common/servable.h" | |||
| #include "proto/ms_service.pb.h" | |||
| #include "common/grpc_client.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| using DispatchCallback = std::function<void(Status status)>; | |||
| class MS_API BaseNotifyWorker { | |||
| public: | |||
| BaseNotifyWorker() = default; | |||
| @@ -20,7 +20,6 @@ | |||
| #include <thread> | |||
| #include "common/exit_handle.h" | |||
| #include "common/grpc_server.h" | |||
| #include "master/grpc/grpc_client.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -56,10 +55,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro | |||
| << worker_address_; | |||
| } | |||
| if (!client_) { | |||
| client_ = std::make_unique<MSServiceClient>(); | |||
| client_ = std::make_unique<MSPredictClient>(); | |||
| client_->Start(); | |||
| } | |||
| client_->PredictAsync(request, reply, stub_, callback); | |||
| client_->PredictAsync(request, reply, stub_.get(), callback); | |||
| return SUCCESS; | |||
| } | |||
| @@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma | |||
| if (grpc_async_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; | |||
| } | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| if (max_msg_mb_size > gRpcMaxMBMsgSize) { | |||
| MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size | |||
| << "MB to 512MB"; | |||
| @@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma | |||
| } | |||
| Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, | |||
| "Master"); | |||
| } | |||
| Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, | |||
| int time_out_second) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second); | |||
| } | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "python/agent/agent_py.h" | |||
| #include "common/exit_handle.h" | |||
| #include "worker/distributed_worker/agent_startup.h" | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| namespace mindspore::serving { | |||
| DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { | |||
| auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| DistributedServableConfig config; | |||
| status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| return config; | |||
| } | |||
| void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { | |||
| WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port); | |||
| } | |||
| void PyAgent::StartAgent(const AgentStartUpConfig &start_config) { | |||
| auto status = WorkerAgent::Instance().StartAgent(start_config); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyAgent::WaitAndClear() { | |||
| { | |||
| py::gil_scoped_release release; | |||
| ExitSignalHandle::Instance().AgentWait(); | |||
| } | |||
| WorkerAgent::Instance().Clear(); | |||
| MSI_LOG_INFO << "Python agent end wait and clear"; | |||
| } | |||
| void PyAgent::StopAndClear() { | |||
| ExitSignalHandle::Instance().Stop(); | |||
| WorkerAgent::Instance().Clear(); | |||
| } | |||
| } // namespace mindspore::serving | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVER_AGENT_PY_H | |||
| #define MINDSPORE_SERVER_AGENT_PY_H | |||
| #include <pybind11/pybind11.h> | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/stl.h> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "common/serving_common.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API PyAgent { | |||
| public: | |||
| static void StartAgent(const AgentStartUpConfig &start_config); | |||
| static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); | |||
| static void WaitAndClear(); | |||
| static void StopAndClear(); | |||
| // from start up, not agent | |||
| static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port); | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVER_AGENT_PY_H | |||
| @@ -23,10 +23,14 @@ | |||
| #include "common/servable.h" | |||
| #include "worker/context.h" | |||
| #include "python/master/master_py.h" | |||
| #include "python/agent/agent_py.h" | |||
| #include "common/exit_handle.h" | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| namespace mindspore::serving { | |||
| PYBIND11_MODULE(_mindspore_serving, m) { | |||
| void PyRegServable(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| // avoid as numpy object memory copy in PyTensor::AsPythonData | |||
| py::class_<TensorBase, TensorBasePtr>(m, "Tensor_"); | |||
| @@ -68,16 +72,30 @@ PYBIND11_MODULE(_mindspore_serving, m) { | |||
| .def_readwrite("version_number", &RequestSpec::version_number) | |||
| .def_readwrite("method_name", &RequestSpec::method_name); | |||
| py::class_<CommonServableMeta>(m, "CommonServableMeta_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("servable_name", &CommonServableMeta::servable_name) | |||
| .def_readwrite("inputs_count", &CommonServableMeta::inputs_count) | |||
| .def_readwrite("outputs_count", &CommonServableMeta::outputs_count) | |||
| .def_readwrite("with_batch_dim", &CommonServableMeta::with_batch_dim) | |||
| .def_readwrite("without_batch_dim_inputs", &CommonServableMeta::without_batch_dim_inputs); | |||
| py::class_<LocalServableMeta>(m, "LocalServableMeta_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("servable_file", &LocalServableMeta::servable_file) | |||
| .def_readwrite("options", &LocalServableMeta::load_options) | |||
| .def("set_model_format", &LocalServableMeta::SetModelFormat); | |||
| py::class_<DistributedServableMeta>(m, "DistributedServableMeta_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("rank_size", &DistributedServableMeta::rank_size) | |||
| .def_readwrite("stage_size", &DistributedServableMeta::stage_size); | |||
| py::class_<ServableMeta>(m, "ServableMeta_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("servable_name", &ServableMeta::servable_name) | |||
| .def_readwrite("inputs_count", &ServableMeta::inputs_count) | |||
| .def_readwrite("outputs_count", &ServableMeta::outputs_count) | |||
| .def_readwrite("servable_file", &ServableMeta::servable_file) | |||
| .def_readwrite("with_batch_dim", &ServableMeta::with_batch_dim) | |||
| .def_readwrite("options", &ServableMeta::load_options) | |||
| .def_readwrite("without_batch_dim_inputs", &ServableMeta::without_batch_dim_inputs) | |||
| .def("set_model_format", &ServableMeta::SetModelFormat); | |||
| .def_readwrite("common_meta", &ServableMeta::common_meta) | |||
| .def_readwrite("local_meta", &ServableMeta::local_meta) | |||
| .def_readwrite("distributed_meta", &ServableMeta::distributed_meta); | |||
| py::class_<ServableSignature>(m, "ServableSignature_") | |||
| .def(py::init<>()) | |||
| @@ -87,8 +105,34 @@ PYBIND11_MODULE(_mindspore_serving, m) { | |||
| py::class_<PyServableStorage>(m, "ServableStorage_") | |||
| .def_static("register_servable_input_output_info", &PyServableStorage::RegisterInputOutputInfo) | |||
| .def_static("register_method", &PyServableStorage::RegisterMethod) | |||
| .def_static("declare_servable", &PyServableStorage::DeclareServable); | |||
| .def_static("declare_servable", &PyServableStorage::DeclareServable) | |||
| .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); | |||
| py::class_<OneRankConfig>(m, "OneRankConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("device_id", &OneRankConfig::device_id) | |||
| .def_readwrite("ip", &OneRankConfig::ip); | |||
| py::class_<DistributedServableConfig>(m, "DistributedServableConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("common_meta", &DistributedServableConfig::common_meta) | |||
| .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) | |||
| .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) | |||
| .def_readwrite("rank_list", &DistributedServableConfig::rank_list); | |||
| } | |||
| void PyRegMaster(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| py::class_<PyMaster>(m, "Master_") | |||
| .def_static("start_grpc_server", &PyMaster::StartGrpcServer) | |||
| .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) | |||
| .def_static("start_restful_server", &PyMaster::StartRestfulServer) | |||
| .def_static("wait_and_clear", &PyMaster::WaitAndClear) | |||
| .def_static("stop_and_clear", &PyMaster::StopAndClear); | |||
| } | |||
| void PyRegWorker(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| py::class_<TaskContext>(m, "TaskContext_").def(py::init<>()); | |||
| py::class_<TaskItem>(m, "TaskItem_") | |||
| @@ -108,6 +152,8 @@ PYBIND11_MODULE(_mindspore_serving, m) { | |||
| py::class_<PyWorker>(m, "Worker_") | |||
| .def_static("start_servable", &PyWorker::StartServable) | |||
| .def_static("start_servable_in_master", &PyWorker::StartServableInMaster) | |||
| .def_static("start_distributed_servable", &PyWorker::StartDistributedServable) | |||
| .def_static("start_distributed_servable_in_master", &PyWorker::StartDistributedServableInMaster) | |||
| .def_static("get_batch_size", &PyWorker::GetBatchSize) | |||
| .def_static("wait_and_clear", &PyWorker::WaitAndClear) | |||
| .def_static("stop_and_clear", PyWorker::StopAndClear) | |||
| @@ -130,17 +176,52 @@ PYBIND11_MODULE(_mindspore_serving, m) { | |||
| } | |||
| }) | |||
| .def("set_device_id", &ServableContext::SetDeviceId); | |||
| } | |||
| py::class_<PyMaster, std::shared_ptr<PyMaster>>(m, "Master_") | |||
| .def_static("start_grpc_server", &PyMaster::StartGrpcServer) | |||
| .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) | |||
| .def_static("start_restful_server", &PyMaster::StartRestfulServer) | |||
| .def_static("wait_and_clear", &PyMaster::WaitAndClear) | |||
| .def_static("stop_and_clear", &PyMaster::StopAndClear); | |||
| void PyRegWorkerAgent(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| py::class_<PyAgent>(m, "WorkerAgent_") | |||
| .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) | |||
| .def_static("wait_and_clear", &PyAgent::WaitAndClear) | |||
| .def_static("stop_and_clear", &PyAgent::StopAndClear) | |||
| .def_static("notify_failed", &PyAgent::NotifyFailed) | |||
| .def_static("start_agent", &PyAgent::StartAgent); | |||
| py::class_<AgentStartUpConfig>(m, "AgentStartUpConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) | |||
| .def_readwrite("device_id", &AgentStartUpConfig::device_id) | |||
| .def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name) | |||
| .def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name) | |||
| .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) | |||
| .def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip) | |||
| .def_readwrite("agent_port", &AgentStartUpConfig::agent_port) | |||
| .def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip) | |||
| .def_readwrite("worker_port", &AgentStartUpConfig::worker_port) | |||
| .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); | |||
| } | |||
| class PyExitSignalHandle { | |||
| public: | |||
| static void Start() { ExitSignalHandle::Instance().Start(); } | |||
| static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } | |||
| }; | |||
| // cppcheck-suppress syntaxError | |||
| PYBIND11_MODULE(_mindspore_serving, m) { | |||
| PyRegServable(&m); | |||
| PyRegMaster(&m); | |||
| PyRegWorker(&m); | |||
| PyRegWorkerAgent(&m); | |||
| py::class_<PyExitSignalHandle>(m, "ExitSignalHandle_") | |||
| .def_static("start", &PyExitSignalHandle::Start) | |||
| .def_static("has_stopped", &PyExitSignalHandle::HasStopped); | |||
| (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { | |||
| Server::Instance().Clear(); | |||
| Worker::GetInstance().Clear(); | |||
| WorkerAgent::Instance().Clear(); | |||
| }}); | |||
| } | |||
| @@ -25,7 +25,16 @@ void PyServableStorage::RegisterMethod(const MethodSignature &method) { | |||
| } | |||
| } | |||
| void PyServableStorage::DeclareServable(const ServableMeta &servable) { | |||
| ServableStorage::Instance().DeclareServable(servable); | |||
| auto status = ServableStorage::Instance().DeclareServable(servable); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyServableStorage::DeclareDistributedServable(const ServableMeta &servable) { | |||
| auto status = ServableStorage::Instance().DeclareDistributedServable(servable); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, | |||
| size_t outputs_count) { | |||
| @@ -27,6 +27,7 @@ class MS_API PyServableStorage { | |||
| static void RegisterMethod(const MethodSignature &method); | |||
| static void DeclareServable(const ServableMeta &servable); | |||
| static void DeclareDistributedServable(const ServableMeta &servable); | |||
| static void RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); | |||
| static void Clear(); | |||
| @@ -21,21 +21,33 @@ | |||
| #include "common/exit_handle.h" | |||
| #include "worker/notfiy_master/grpc_notify.h" | |||
| #include "worker/notfiy_master/local_notify.h" | |||
| #include "worker/local_servable/local_sevable.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_server.h" | |||
| namespace mindspore::serving { | |||
| void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, | |||
| const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||
| uint32_t host_port) { | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, host_ip, host_port); | |||
| auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); | |||
| const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, | |||
| uint32_t worker_port) { | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||
| auto servable = std::make_shared<LocalModelServable>(); | |||
| auto status = servable->StartServable(model_directory, model_name, version_number); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); | |||
| status = Worker::GetInstance().StartServable(servable, notify_master); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| // start grpc server | |||
| auto grpc_sever = std::make_shared<MSWorkerServer>(); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartVersionController(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| @@ -45,7 +57,69 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri | |||
| void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name, | |||
| uint32_t version_number) { | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); | |||
| auto servable = std::make_shared<LocalModelServable>(); | |||
| auto status = servable->StartServable(model_directory, model_name, version_number); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartServable(servable, notify_master); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartVersionController(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| const std::string &master_ip, uint32_t master_port, | |||
| uint32_t wait_agents_time_in_seconds) { | |||
| Status status; | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, | |||
| wait_agents_time_in_seconds); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartServable(servable, notify_master); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartVersionController(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| uint32_t wait_agents_time_in_seconds) { | |||
| Status status; | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, | |||
| wait_agents_time_in_seconds); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartServable(servable, notify_master); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| @@ -34,6 +34,16 @@ class MS_API PyWorker { | |||
| static void StartServableInMaster(const std::string &model_directory, const std::string &model_name, | |||
| uint32_t version_number); | |||
| static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, | |||
| uint32_t master_port, uint32_t wait_agents_time_in_seconds); | |||
| static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| uint32_t wait_agents_time_in_seconds); | |||
| static int GetBatchSize(); | |||
| static void WaitAndClear(); | |||
| static void StopAndClear(); | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/agent_executor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Placeholder implementations of WorkerAgentExecutor: every method is a stub.
// NOTE(review): these return a default-constructed Status / empty infos / 0 —
// presumably the real device inference calls are filled in elsewhere; confirm
// before relying on any of them.
Status WorkerAgentExecutor::LoadModelFromFile(const AgentStartUpConfig &config) { return Status(); }
// Stub: nothing is unloaded.
Status WorkerAgentExecutor::UnloadModel() { return Status(); }
// Stub: neither `request` nor `reply` is touched.
Status WorkerAgentExecutor::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) {
  return Status();
}
// Stub: reports no model inputs.
std::vector<serving::TensorInfo> WorkerAgentExecutor::GetInputInfos() const {
  return std::vector<serving::TensorInfo>();
}
// Stub: reports no model outputs.
std::vector<serving::TensorInfo> WorkerAgentExecutor::GetOutputInfos() const {
  return std::vector<serving::TensorInfo>();
}
// Stub: batch size reported as 0.
ssize_t WorkerAgentExecutor::GetBatchSize() const { return 0; }
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H | |||
| #define MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H | |||
| #include <vector> | |||
| #include "common/serving_common.h" | |||
| #include "worker/inference/inference.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Executes one model shard on behalf of the distributed worker.
// Lifecycle: LoadModelFromFile -> ExecuteModel (repeatedly) -> UnloadModel.
class MS_API WorkerAgentExecutor {
 public:
  // from python
  // Load the shard described by `config` (rank/device/model/group files).
  Status LoadModelFromFile(const AgentStartUpConfig &config);
  // ctrl+c, worker exit
  // Release the loaded model.
  Status UnloadModel();
  // from worker
  // Run one inference: inputs in `request`, results written to `*reply`.
  Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply);
  // for register
  // Tensor metadata reported to the worker at agent registration time.
  std::vector<serving::TensorInfo> GetInputInfos() const;
  std::vector<serving::TensorInfo> GetOutputInfos() const;
  ssize_t GetBatchSize() const;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/agent_process/agent_process.h" | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// gRPC handler: the worker asks this agent process to shut down.
// Delegates to WorkerAgent::StopAgent(false) and always answers grpc OK.
// NOTE(review): the log text says "Distributed Worker Exit" although this runs
// in the agent — possibly copied boilerplate; confirm intended wording.
grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
                               proto::DistributedExitReply *reply) {
  MSI_LOG(INFO) << "Distributed Worker Exit";
  WorkerAgent::Instance().StopAgent(false);
  return grpc::Status::OK;
}
// gRPC handler: run one inference on this agent's model shard.
// Forwards request/reply to WorkerAgent::Run; always answers grpc OK
// (application-level errors travel inside `reply`, not the grpc status).
grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
                                  proto::DistributedPredictReply *reply) {
  MSI_LOG(INFO) << "Begin call service Eval";
  WorkerAgent::Instance().Run(*request, reply);
  return grpc::Status::OK;
}
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H | |||
| #define MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "proto/ms_agent.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
// Synchronous gRPC service exposed by one worker agent process.
// The distributed worker calls Predict to run this agent's shard and Exit to
// ask the agent to shut down.
class MSAgentImpl final : public proto::MSAgent::Service {
 public:
  // Run one inference on the local model shard.
  grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request,
                       proto::DistributedPredictReply *reply) override;
  // Stop this agent process on request of the worker.
  grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
                    proto::DistributedExitReply *reply) override;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/agent_startup.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Meyers-singleton accessor: the single process-wide start-up helper.
WorkerAgentStartUp &WorkerAgentStartUp::Instance() {
  static WorkerAgentStartUp instance;
  return instance;
}
// Placeholder: fetching the agents' configuration from the worker at
// worker_ip:worker_port is not implemented yet — config_ is never populated
// here, so GetDistributedServableConfig will keep failing until it is.
Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) {
  return Status();
}
| Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { | |||
| MSI_EXCEPTION_IF_NULL(config); | |||
| if (config_.rank_list.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; | |||
| } | |||
| *config = config_; | |||
| return SUCCESS; | |||
| } | |||
// Thin forwarder: tell the distributed worker at worker_ip:worker_port that
// this agent failed to start.
Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
  return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port);
}
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H | |||
| #define MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include "common/serving_common.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "worker/inference/inference.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Start-up coordinator for a worker agent process: pulls the distributed
// servable configuration from the worker and reports start-up failures back.
class MS_API WorkerAgentStartUp {
 public:
  static WorkerAgentStartUp &Instance();
  // from python, worker_agent.py
  // start_worker_agent
  // step1, get agents config from worker
  Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
  // step2, invoke from python
  // Copy the cached config out; fails until step1 has filled it in.
  Status GetDistributedServableConfig(DistributedServableConfig *config);
  // Report start-up failure to the worker.
  Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);
 private:
  // Config received from the worker (rank table + servable metas).
  DistributedServableConfig config_;
  // NOTE(review): worker_address_ is not used in the visible code — confirm
  // it is set/read elsewhere or remove it.
  std::string worker_address_;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H | |||
| #define MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| #include "common/serving_common.h" | |||
| #include "worker/inference/inference.h" | |||
| #include "common/servable.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Placement of one rank: which host (ip) and which device on that host.
struct OneRankConfig {
  std::string ip;
  uint32_t device_id = 0;
};
// Full configuration of a distributed servable: the raw rank table text,
// the per-rank placements, and the common/distributed servable metadata.
struct DistributedServableConfig {
  std::string rank_table_content;
  std::vector<OneRankConfig> rank_list;
  CommonServableMeta common_meta;
  DistributedServableMeta distributed_meta;
};
| struct AgentStartUpConfig { | |||
| uint32_t rank_id; | |||
| uint32_t device_id; | |||
| std::string model_file_name; | |||
| std::string group_file_name; | |||
| std::string rank_table_json_file_name; | |||
| std::string agent_ip; | |||
| uint32_t agent_port; | |||
| std::string worker_ip; | |||
| uint32_t worker_port; | |||
| CommonServableMeta common_meta; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||
| #include "worker/worker.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, | |||
| proto::AgentRegisterReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(request); | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| for (auto &spec : request->agent_spec()) { | |||
| WorkerAgentSpec agent_spec; | |||
| agent_spec.agent_address = request->address(); | |||
| GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); | |||
| Status status(FAILED); | |||
| status = servable_->RegisterAgent(agent_spec); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG(ERROR) << "Agent Register FAILED"; | |||
| } | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, | |||
| proto::AgentExitReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(request); | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| for (auto &spec : request->agent_spec()) { | |||
| WorkerAgentSpec agent_spec; | |||
| agent_spec.agent_address = request->address(); | |||
| GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); | |||
| Status status(FAILED); | |||
| status = servable_->UnregisterAgent(agent_spec); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG(ERROR) << "Agent Exit FAILED"; | |||
| } | |||
| } | |||
| if (Worker::GetInstance().IsRunning()) { | |||
| Worker::GetInstance().StopServable(); | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, | |||
| proto::AgentFailedReply *reply) { | |||
| if (Worker::GetInstance().IsRunning()) { | |||
| MSI_LOG_ERROR << "Expect worker should not be running"; | |||
| Worker::GetInstance().StopServable(); | |||
| } else { | |||
| servable_->OnAgentFailed(); | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H | |||
| #define MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <memory> | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_service.pb.h" | |||
| #include "proto/ms_service.grpc.pb.h" | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "proto/ms_distributed.grpc.pb.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
// Worker-side service for a distributed servable: extends the plain worker
// service with agent lifecycle RPCs (register / exit / failed).
// Shares ownership of the servable with the server that created it.
class MSDistributedImpl final : public MSWorkerImpl {
 public:
  explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
  // NOTE(review): non-virtual-looking `= default` dtor is fine only if
  // MSWorkerImpl declares a virtual destructor — confirm in worker_process.h.
  ~MSDistributedImpl() = default;
  // An agent registers its shard specs.
  grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request,
                             proto::AgentRegisterReply *reply) override;
  // An agent announces orderly exit.
  grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
                         proto::AgentExitReply *reply) override;
  // An agent reports start-up failure.
  grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
                           proto::AgentFailedReply *reply) override;
 private:
  std::shared_ptr<DistributedServable> servable_;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/distributed_process/distributed_server.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "common/grpc_server.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Starts the distributed worker's gRPC endpoint on hostname:port.
// Fails if a server is already running; otherwise wires a MSDistributedImpl
// service into a DistributedWorkerGrpcServer and runs Init().
Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
  if (in_running_) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
  }
  auto impl = std::make_unique<MSDistributedImpl>(servable_);
  // Order matters: impl.get() is taken before std::move empties `impl`.
  // The async server only borrows the pointer; ownership moves to service_impl_.
  async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get());
  service_impl_ = std::move(impl);
  return Init();
}
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,178 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| #define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_worker.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
// Worker gRPC server specialized for distributed servables: plugs the
// distributed service implementation (agent RPCs) into the base server.
class MS_API MSDistributedWorkerServer : public MSWorkerServer {
 public:
  explicit MSDistributedWorkerServer(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
  ~MSDistributedWorkerServer() = default;
  // Builds a MSDistributedImpl + DistributedWorkerGrpcServer and starts it.
  Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override;
 private:
  // Servable shared with the service impl created in StartWorkerGrpcServer.
  std::shared_ptr<DistributedServable> servable_;
};
// Base call-context for the distributed worker's async RPCs: same machinery
// as WorkerServiceContext, but keeps a typed pointer to the distributed
// service impl so subclasses can invoke the agent RPC handlers.
class DistributedServiceContext : public WorkerServiceContext {
 public:
  DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                            grpc::ServerCompletionQueue *cq)
      : WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {}
 protected:
  // Non-owning; outlives the context (owned by the server).
  MSDistributedImpl *dist_service_impl_ = nullptr;
};
| // Service Implement | |||
// Async call-context for one AgentRegister RPC (grpc completion-queue
// state machine). EnqueueRequest `new`s an instance that posts itself on the
// queue; HandleRequest re-arms a fresh context, runs the handler, and sends
// the reply. NOTE(review): no `delete` is visible here — presumably the base
// state machine frees the context after the FINISH event; confirm in
// WorkerServiceContext.
class WorkerAgentRegisterContext : public DistributedServiceContext {
 public:
  WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                             grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentRegisterContext() = default;
  // Allocate a context and arm it on the completion queue.
  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                               grpc::ServerCompletionQueue *cq) {
    auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq);
    call->StartEnqueueRequest();
    return SUCCESS;
  }
  void StartEnqueueRequest() override {
    state_ = STATE::PROCESS;
    async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    // Arm the next context first so the server keeps accepting this RPC.
    EnqueueRequest(dist_service_impl_, async_service_, cq_);
    state_ = STATE::FINISH;
    grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }
 private:
  grpc::ServerAsyncResponseWriter<proto::AgentRegisterReply> responder_;
  proto::AgentRegisterRequest request_;
  proto::AgentRegisterReply response_;
};
// Async call-context for one AgentExit RPC; same self-arming completion-queue
// pattern as the other agent contexts (allocate via EnqueueRequest, re-arm in
// HandleRequest, reply via responder_.Finish).
class WorkerAgentExitContext : public DistributedServiceContext {
 public:
  WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                         grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentExitContext() = default;
  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                               grpc::ServerCompletionQueue *cq) {
    auto call = new WorkerAgentExitContext(service_impl, async_service, cq);
    call->StartEnqueueRequest();
    return SUCCESS;
  }
  void StartEnqueueRequest() override {
    state_ = STATE::PROCESS;
    async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    // Re-arm before handling so further AgentExit calls are not dropped.
    EnqueueRequest(dist_service_impl_, async_service_, cq_);
    state_ = STATE::FINISH;
    grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }
 private:
  grpc::ServerAsyncResponseWriter<proto::AgentExitReply> responder_;
  proto::AgentExitRequest request_;
  proto::AgentExitReply response_;
};
// Async call-context for one AgentFailed RPC; same self-arming
// completion-queue pattern as WorkerAgentRegisterContext/WorkerAgentExitContext.
class WorkerAgentFailedContext : public DistributedServiceContext {
 public:
  WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                           grpc::ServerCompletionQueue *cq)
      : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
  ~WorkerAgentFailedContext() = default;
  static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
                               grpc::ServerCompletionQueue *cq) {
    auto call = new WorkerAgentFailedContext(service_impl, async_service, cq);
    call->StartEnqueueRequest();
    return SUCCESS;
  }
  void StartEnqueueRequest() override {
    state_ = STATE::PROCESS;
    async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
  }
  void HandleRequest() override {
    // Re-arm before handling so the RPC stays serviceable.
    EnqueueRequest(dist_service_impl_, async_service_, cq_);
    state_ = STATE::FINISH;
    grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_);
    responder_.Finish(response_, status, this);
  }
 private:
  grpc::ServerAsyncResponseWriter<proto::AgentFailedReply> responder_;
  proto::AgentFailedRequest request_;
  proto::AgentFailedReply response_;
};
// Worker gRPC async server extended with the agent RPC contexts.
class DistributedWorkerGrpcServer : public WorkerGrpcServer {
 public:
  DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl)
      : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {}
  ~DistributedWorkerGrpcServer() = default;
  // Arms the base worker RPCs plus the three agent RPCs on the queue.
  // NOTE(review): not marked `override` — if WorkerGrpcServer::EnqueueRequest
  // is virtual this silently hides instead of overriding (calls through the
  // base type would skip the agent contexts); confirm in worker_server.h.
  Status EnqueueRequest() {
    WorkerGrpcServer::EnqueueRequest();
    WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
    WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
    WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
    return SUCCESS;
  }
 private:
  // Non-owning; the server that created us owns the impl.
  MSDistributedImpl *distributed_service_impl_;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| @@ -0,0 +1,335 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include <set> | |||
| #include "worker/distributed_worker/notify_agent/notify_agent.h" | |||
| #include "common/exit_handle.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Tear down all registered agents on destruction.
DistributedServable::~DistributedServable() { Clear(); }
// Simple accessors for the servable identity.
std::string DistributedServable::GetServableName() const { return servable_name_; }
uint64_t DistributedServable::GetServableVersion() const { return version_number_; }
// Serve one prediction. Currently only guards the load state: the actual
// fan-out to agents is not implemented yet — `input`/`output` are untouched
// and a default Status is returned. TODO(review): confirm implementation
// lands before this path is exercised.
Status DistributedServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return Status();
}
// Cached model metadata accessors. Each one raises (MSI_LOG_EXCEPTION) when
// the model has not been loaded; the cached values are filled during
// StartServable/CheckAgentsInfosAndInitTensorInfos.
std::vector<TensorInfo> DistributedServable::GetInputInfos() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return input_infos_;
}
std::vector<TensorInfo> DistributedServable::GetOutputInfos() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return output_infos_;
}
uint64_t DistributedServable::GetBatchSize() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return batch_size_;
}
// Copies the current config out unconditionally and always reports SUCCESS.
// NOTE(review): unlike WorkerAgentStartUp::GetDistributedServableConfig this
// does not check readiness — callers may receive a partially filled config;
// confirm that is intended.
Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) const {
  *config = config_;
  return SUCCESS;
}
// Fulfill the "all agents ready" promise exactly once (true = all registered,
// false = registration failed/agent left). The atomic_flag guard makes later
// calls no-ops, since setting a std::promise value twice would throw.
void DistributedServable::SetWaitAgentsPromise(bool flag) {
  if (!promise_set_flag_.test_and_set()) {
    agents_promise_.set_value(flag);
  }
}
| Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| if (agent_spec.rank_id < config_.distributed_meta.rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; | |||
| } | |||
| DistributedAgentContext context; | |||
| auto it = agent_spec_map_.find(agent_spec.rank_id); | |||
| if (it != agent_spec_map_.end()) { | |||
| MSI_LOG_WARNING << "rank_id " << agent_spec.rank_id << " has been registered"; | |||
| return SUCCESS; | |||
| } | |||
| context.agent_spec_ = agent_spec; | |||
| std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotfiyAgent>(agent_spec.agent_address); | |||
| context.notify_agent_ = notify_agent; | |||
| agent_spec_map_[agent_spec.rank_id] = context; | |||
| if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { | |||
| SetWaitAgentsPromise(true); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
// Ask every registered agent to exit, then drop all registrations.
// Called from the destructor; holds the same mutex as (un)register.
void DistributedServable::Clear() {
  std::unique_lock<std::mutex> lock{mutex_};
  for (auto &agent : agent_spec_map_) {
    agent.second.notify_agent_->Exit();
  }
  agent_spec_map_.clear();
  MSI_LOG_INFO << "End Clear servable";
}
| Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { | |||
| if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { | |||
| iter = agent_spec_map_.erase(iter); | |||
| } else { | |||
| ++iter; | |||
| } | |||
| } | |||
| SetWaitAgentsPromise(false); | |||
| return SUCCESS; | |||
| } | |||
// Loads and starts the distributed servable:
//   1) look up the registered servable definition and check it is distributed,
//   2) parse the rank table, 3) validate the rank config,
//   4) wait (bounded by wait_agents_time_in_seconds) for all agents,
//   5) cross-check agent infos and cache the tensor infos.
// On success model_loaded_ is set; any stage failure aborts with its status.
// NOTE(review): servable_directory is unused in this body — confirm intended.
Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name,
                                          const std::string &rank_table_json_file, uint64_t version_number,
                                          uint64_t wait_agents_time_in_seconds) {
  if (model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has loaded";
  }
  version_number_ = version_number;
  servable_name_ = servable_name;
  rank_table_json_file_ = rank_table_json_file;
  // Stage 1: the servable must be registered and of distributed type.
  ServableSignature signature;
  if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
  }
  auto &meta = signature.servable_meta;
  if (meta.servable_type != kServableTypeDistributed) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Servable '" << servable_name << "' is not registered as distributed servable, " << meta.Repr();
  }
  config_.common_meta = meta.common_meta;
  config_.distributed_meta = meta.distributed_meta;
  // Stage 2: parse the rank table (currently stubbed to fail, see
  // InitConfigOnStartup).
  auto status = InitConfigOnStartup(rank_table_json_file_);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Init with rank table on start up failed";
    return status;
  }
  // Stage 3: sanity-check the per-rank configuration.
  status = CheckRankConfig();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Check rank config failed";
    return status;
  }
  // Stage 4: block until all agents registered (or timeout/stop).
  status = WaitAgentsReady(wait_agents_time_in_seconds);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Waiting for ready of agents failed";
    return status;
  }
  // Stage 5: validate agent-reported infos and cache tensor metadata.
  status = CheckAgentsInfosAndInitTensorInfos();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Check agents infos failed";
    return status;
  }
  model_loaded_ = true;
  return SUCCESS;
}
| Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } | |||
| Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { | |||
| auto future = agents_promise_.get_future(); | |||
| if (wait_agents_time_in_seconds == 0) { | |||
| wait_agents_time_in_seconds = UINT32_MAX; | |||
| } | |||
| const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10; | |||
| uint64_t i; | |||
| for (i = 0; i < kWaitMaxHundredMs; i++) { // | |||
| if (ExitSignalHandle::Instance().HasStopped()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; | |||
| } | |||
| // waiting for 100ms | |||
| if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { | |||
| auto flag = future.get(); | |||
| if (!flag) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported"; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| if (i >= kWaitMaxHundredMs) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size() | |||
| << ", rank size: " << config_.distributed_meta.rank_size; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status DistributedServable::CompareTensorInfos(const std::vector<TensorInfo> &lefts, | |||
| const std::vector<TensorInfo> &rights) { | |||
| if (lefts.size() != rights.size()) { | |||
| return INFER_STATUS(FAILED) << "Size not match, left: " << lefts.size() << ", right: " << rights.size(); | |||
| } | |||
| auto tensor_info_as_str = [](const TensorInfo &tensor_info) { | |||
| Status status = INFER_STATUS(SUCCESS) << "size: " << tensor_info.size << ", data type: " << tensor_info.data_type | |||
| << ", shape: " << tensor_info.shape; | |||
| return status.StatusMessage(); | |||
| }; | |||
| for (size_t k = 0; k < lefts.size(); k++) { | |||
| auto &left = lefts[k]; | |||
| auto &right = rights[k]; | |||
| if (left.size != right.size || left.shape != right.shape || left.data_type != right.data_type) { | |||
| return INFER_STATUS(FAILED) << "Index " << k << " tensor not match, left- " << tensor_info_as_str(left) | |||
| << "; right- " << tensor_info_as_str(right); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status DistributedServable::CheckAgentsInfosAndInitTensorInfos() { | |||
| auto rank_size = config_.distributed_meta.rank_size; | |||
| auto stage_size = config_.distributed_meta.stage_size; | |||
| auto parallel_count = rank_size / stage_size; | |||
| MSI_LOG_INFO << "Check agents infos, rank size :" << rank_size << ", stage size: " << stage_size | |||
| << ", parallel count: " << parallel_count; | |||
| if (agent_spec_map_.size() != rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Registered agents size " << agent_spec_map_.size() << " not match rank size " << rank_size; | |||
| } | |||
| input_infos_ = agent_spec_map_[0].agent_spec_.input_infos; | |||
| output_infos_ = agent_spec_map_[rank_size - 1].agent_spec_.output_infos; | |||
| batch_size_ = agent_spec_map_[0].agent_spec_.batch_size; | |||
| if (input_infos_.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << 0 << " input count cannot be 0"; | |||
| } | |||
| if (output_infos_.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size - 1 << " output count cannot be 0"; | |||
| } | |||
| Status status; | |||
| for (size_t i = 0; i < parallel_count; i++) { | |||
| auto &agent_spec = agent_spec_map_[i]; | |||
| status = CompareTensorInfos(agent_spec.agent_spec_.input_infos, input_infos_); | |||
| if (status != SUCCESS) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank " << i << " input infos not match rank 0, details: " << status.StatusMessage(); | |||
| return status; | |||
| } | |||
| } | |||
| for (size_t i = parallel_count; i < rank_size; i++) { | |||
| auto &agent_spec = agent_spec_map_[i]; | |||
| if (!agent_spec.agent_spec_.input_infos.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Expect rank " << i << " input count equal to 0"; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < rank_size; i++) { | |||
| auto &first_item = agent_spec_map_[i]; | |||
| for (size_t k = 0; k < parallel_count && i + k < rank_size; k++) { | |||
| auto rank_id = i + k; | |||
| auto &agent_spec = agent_spec_map_[i + k]; | |||
| status = CompareTensorInfos(agent_spec.agent_spec_.output_infos, first_item.agent_spec_.output_infos); | |||
| if (status != SUCCESS) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size << " output infos not match rank " << i | |||
| << ", details: " << status.StatusMessage(); | |||
| return status; | |||
| } | |||
| if (agent_spec.agent_spec_.batch_size != 0 && agent_spec.agent_spec_.batch_size != batch_size_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Expect rank " << rank_id << " batch size equal to 0 or rank 0 batch size " << batch_size_; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status DistributedServable::CheckRankConfig() { | |||
| auto rank_size = config_.distributed_meta.rank_size; | |||
| auto stage_size = config_.distributed_meta.stage_size; | |||
| if (stage_size == 0 || rank_size == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size or stage size cannot be 0, rank size: " << rank_size << ", stage size: " << stage_size; | |||
| } | |||
| if (rank_size % stage_size != 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size must be an integral multiple of stage size, rank size: " << rank_size | |||
| << ", stage size: " << stage_size; | |||
| } | |||
| if (config_.rank_list.size() != rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " | |||
| << rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_; | |||
| } | |||
| auto parallel_count = rank_size / stage_size; | |||
| constexpr size_t card_count_per_machine = 8; | |||
| if (stage_size == 1) { | |||
| std::map<std::string, std::set<uint32_t>> device_map; | |||
| for (size_t i = 0; i < rank_size; i++) { | |||
| const auto &item = config_.rank_list[i]; | |||
| auto &device_id_list = device_map[item.ip]; | |||
| if (device_id_list.count(item.device_id) > 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " | |||
| << i << " in device ip " << item.ip; | |||
| } | |||
| device_id_list.emplace(item.device_id); | |||
| } | |||
| } else { | |||
| if (rank_size < card_count_per_machine) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size " << rank_size << "must >= card count " << card_count_per_machine | |||
| << " of one machine when stage size " << stage_size << " > 1"; | |||
| } | |||
| if (parallel_count % card_count_per_machine != 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine | |||
| << "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size; | |||
| } | |||
| for (size_t i = 0; i < rank_size; i += card_count_per_machine) { | |||
| const auto &first_item = config_.rank_list[i]; | |||
| for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { | |||
| auto rank_id = i + k; | |||
| const auto &item = config_.rank_list[rank_id]; | |||
| if (k != item.device_id) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; | |||
| } | |||
| if (first_item.ip != item.ip) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id | |||
| << " to be equal with device ip " << first_item.ip << " of rank " << i; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| MSI_LOG_INFO << "Check rank table success, rank size: " << rank_size << ", stage size: " << stage_size | |||
| << ", parallel count in one stage: " << parallel_count; | |||
| return SUCCESS; | |||
| } | |||
| void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H | |||
| #define MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H | |||
#include <atomic>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "worker/sevable_base.h"
#include "worker/distributed_worker/common.h"
#include "worker/distributed_worker/notify_agent/base_notify_agent.h"
| namespace mindspore { | |||
| namespace serving { | |||
// Per-agent bookkeeping kept by DistributedServable: the specification the
// agent registered with, plus the gRPC proxy used to notify/exit that agent.
struct DistributedAgentContext {
  WorkerAgentSpec agent_spec_;  // rank id, address, tensor infos, batch size as registered
  std::shared_ptr<BaseNotifyAgent> notify_agent_ = nullptr;  // channel to the agent process
};
// Servable implementation backed by a group of distributed worker agents
// (one per rank). Collects agent registrations, validates them against the
// rank table, and dispatches predict requests to the agents.
class MS_API DistributedServable : public ServableBase {
 public:
  DistributedServable() = default;
  ~DistributedServable();
  // Called from Python (worker.py): load the servable definition, wait for
  // all agents to register and validate their tensor infos.
  Status StartServable(const std::string &servable_directory, const std::string &servable_name,
                       const std::string &rank_table_json_file, uint64_t version_number,
                       uint64_t wait_agents_time_in_seconds);
  // Invoked from an agent: fetch the distributed configuration (rank table,
  // meta) so the agent can load its model slice.
  Status GetDistributedServableConfig(DistributedServableConfig *config) const;
  // Agent lifecycle: add/remove an agent in agent_spec_map_. Registration of
  // the final agent (or any unregistration) resolves the startup promise.
  Status RegisterAgent(const WorkerAgentSpec &agent_spec);
  Status UnregisterAgent(const WorkerAgentSpec &agent_spec);
  // Run one prediction through the agent group, using config_ and the
  // registered agents.
  Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override;
  std::vector<TensorInfo> GetInputInfos() const override;
  std::vector<TensorInfo> GetOutputInfos() const override;
  uint64_t GetBatchSize() const override;
  std::string GetServableName() const override;
  uint64_t GetServableVersion() const override;
  // Notify all agents to exit and drop their contexts.
  void Clear() override;
  // Agent-failure callback: fails the startup wait.
  void OnAgentFailed();

 private:
  DistributedServableConfig config_;        // common + distributed meta and rank list
  std::string servable_name_;
  uint64_t version_number_ = 0;
  bool model_loaded_ = false;               // set once StartServable succeeds
  std::mutex mutex_;                        // guards agent_spec_map_
  std::map<uint32_t, DistributedAgentContext> agent_spec_map_;  // rank id -> agent context
  std::string rank_table_json_file_;
  // Cached from agents during CheckAgentsInfosAndInitTensorInfos().
  std::vector<TensorInfo> input_infos_;
  std::vector<TensorInfo> output_infos_;
  uint64_t batch_size_ = 0;
  std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT;  // ensures the promise is set only once
  std::promise<bool> agents_promise_;       // true: all agents registered; false: failure
  Status InitConfigOnStartup(const std::string &rank_table_json_file);
  Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds);
  Status CheckAgentsInfosAndInitTensorInfos();
  Status CompareTensorInfos(const std::vector<TensorInfo> &lefts, const std::vector<TensorInfo> &rights);
  Status CheckRankConfig();
  void SetWaitAgentsPromise(bool flag);
  // agent stubs
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H | |||
| #define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H | |||
| #include <vector> | |||
| #include <functional> | |||
| #include <future> | |||
| #include "common/serving_common.h" | |||
| #include "common/servable.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "common/grpc_client.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Interface the distributed worker uses to control a single agent process.
// Concrete implementation: GrpcNotfiyAgent (gRPC transport).
class MS_API BaseNotifyAgent {
 public:
  BaseNotifyAgent() = default;
  virtual ~BaseNotifyAgent() = default;
  // Ask the agent process to exit (best effort).
  virtual Status Exit() = 0;
  // Asynchronously dispatch a predict request to the agent; `callback` is
  // invoked when `reply` has been filled in.
  virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply,
                               DispatchCallback callback) = 0;
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/notify_agent/notify_agent.h" | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <thread> | |||
| #include "common/exit_handle.h" | |||
| #include "common/grpc_server.h" | |||
| #include "common/grpc_client.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &agent_address) { | |||
| agent_address_ = agent_address; | |||
| std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(agent_address_); | |||
| stub_ = proto::MSAgent::NewStub(channel); | |||
| } | |||
| GrpcNotfiyAgent::~GrpcNotfiyAgent() = default; | |||
| Status GrpcNotfiyAgent::Exit() { | |||
| if (stub_) { | |||
| proto::DistributedExitRequest request; | |||
| request.set_address(agent_address_); | |||
| proto::DistributedExitReply reply; | |||
| grpc::ClientContext context; | |||
| const int32_t TIME_OUT = 1; | |||
| std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); | |||
| context.set_deadline(deadline); | |||
| (void)stub_->Exit(&context, request, &reply); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
// Forward a predict request to the agent asynchronously through the shared
// distributed gRPC client; `callback` fires when `reply` is filled in.
Status GrpcNotfiyAgent::DispatchAsync(const proto::DistributedPredictRequest &request,
                                      proto::DistributedPredictReply *reply, DispatchCallback callback) {
  if (!stub_) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_;
  }
  // Lazily create and start the async client on first dispatch.
  // NOTE(review): this lazy init is unsynchronized -- confirm DispatchAsync is
  // only ever invoked from a single thread.
  if (!distributed_client_) {
    distributed_client_ = std::make_unique<MSDistributedClient>();
    distributed_client_->Start();
  }
  distributed_client_->PredictAsync(request, reply, stub_.get(), callback);
  return SUCCESS;
}  // (previous "// namespace serving" comment here was misplaced -- this brace closes the function)
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H | |||
| #define MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <atomic> | |||
| #include "worker/distributed_worker/notify_agent/base_notify_agent.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "proto/ms_agent.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API GrpcNotfiyAgent : public BaseNotifyAgent { | |||
| public: | |||
| explicit GrpcNotfiyAgent(const std::string &worker_address); | |||
| ~GrpcNotfiyAgent() override; | |||
| Status Exit() override; | |||
| Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, | |||
| DispatchCallback callback) override; | |||
| private: | |||
| std::string agent_address_; | |||
| std::shared_ptr<proto::MSAgent::Stub> stub_ = nullptr; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <thread> | |||
| #include "common/exit_handle.h" | |||
| #include "common/grpc_server.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip, | |||
| uint32_t distributed_worker_port, const std::string &host_ip, | |||
| uint32_t host_port) | |||
| : distributed_worker_ip_(distributed_worker_ip), | |||
| distributed_worker_port_(distributed_worker_port), | |||
| host_ip_(host_ip), | |||
| host_port_(host_port) { | |||
| distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); | |||
| agent_address_ = host_ip_ + ":" + std::to_string(host_port_); | |||
| auto channel = GrpcServer::CreateChannel(distributed_worker_address_); | |||
| stub_ = proto::MSWorker::NewStub(channel); | |||
| } | |||
| GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; | |||
| Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) { | |||
| const int32_t REGISTER_TIME_OUT = 60; | |||
| const int32_t REGISTER_INTERVAL = 1; | |||
| auto loop = REGISTER_TIME_OUT; | |||
| while (loop-- && !ExitSignalHandle::Instance().HasStopped()) { | |||
| MSI_LOG(INFO) << "Register to " << distributed_worker_address_; | |||
| proto::AgentRegisterRequest request; | |||
| GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request); | |||
| proto::AgentRegisterReply reply; | |||
| grpc::ClientContext context; | |||
| std::chrono::system_clock::time_point deadline = | |||
| std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL); | |||
| context.set_deadline(deadline); | |||
| grpc::Status status = stub_->AgentRegister(&context, request, &reply); | |||
| if (status.ok()) { | |||
| MSI_LOG(INFO) << "Register SUCCESS "; | |||
| return SUCCESS; | |||
| } | |||
| MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); | |||
| } | |||
| if (ExitSignalHandle::Instance().HasStopped()) { | |||
| return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration"; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; | |||
| } | |||
| Status GrpcNotifyDistributeWorker::Unregister() { | |||
| if (is_stoped_.load()) { | |||
| return SUCCESS; | |||
| } | |||
| is_stoped_ = true; | |||
| proto::AgentExitRequest request; | |||
| request.set_address(agent_address_); | |||
| proto::AgentExitReply reply; | |||
| grpc::ClientContext context; | |||
| const int32_t TIME_OUT = 1; | |||
| std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); | |||
| context.set_deadline(deadline); | |||
| grpc::Status status = stub_->AgentExit(&context, request, &reply); | |||
| if (status.ok()) { | |||
| MSI_LOG(INFO) << "Exit SUCCESS "; | |||
| return SUCCESS; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; | |||
| } | |||
| Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { | |||
| auto address = worker_ip + ":" + std::to_string(worker_port); | |||
| auto channel = GrpcServer::CreateChannel(address); | |||
| auto stub = proto::MSWorker::NewStub(channel); | |||
| grpc::ClientContext context; | |||
| proto::AgentFailedRequest request; | |||
| proto::AgentFailedReply reply; | |||
| grpc::Status status = stub->AgentFailed(&context, request, &reply); | |||
| if (status.ok()) { | |||
| MSI_LOG(INFO) << "Success to notify failure of agent"; | |||
| return SUCCESS; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent"; | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H | |||
| #define MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "common/serving_common.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "proto/ms_distributed.grpc.pb.h" | |||
| #include "proto/ms_worker.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Agent-side gRPC client used to talk to the distributed worker: registers
// this agent's specs, unregisters on shutdown, and reports startup failures.
// NOTE(review): std::atomic is used below but <atomic> is not included here --
// confirm it arrives transitively (e.g. via common/serving_common.h).
class MS_API GrpcNotifyDistributeWorker {
 public:
  // worker_ip/worker_port: address of the distributed worker to notify;
  // agent_ip/agent_port: this agent's own address, reported to the worker.
  GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip,
                             uint32_t agent_port);
  ~GrpcNotifyDistributeWorker();
  // Register the given agent specs with the worker (retries with timeout).
  Status Register(const std::vector<WorkerAgentSpec> &agent_specs);
  // Tell the worker this agent is exiting; idempotent.
  Status Unregister();
  // Report startup failure; called from start up, not from a live agent.
  static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);

 private:
  std::string distributed_worker_ip_;
  uint32_t distributed_worker_port_;
  std::string host_ip_;    // this agent's ip
  uint32_t host_port_;     // this agent's port
  std::string agent_address_;               // "host_ip:host_port"
  std::string distributed_worker_address_;  // "worker_ip:worker_port"
  std::unique_ptr<proto::MSWorker::Stub> stub_;
  std::atomic<bool> is_stoped_{false};  // set once Unregister has run (sic: "stoped")
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| #include <memory> | |||
| #include "worker/distributed_worker/agent_process/agent_process.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| #include "common/exit_handle.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Process-wide singleton accessor (initialization is thread-safe via C++11
// function-local static semantics).
WorkerAgent &WorkerAgent::Instance() {
  static WorkerAgent instance;
  return instance;
}
| Status WorkerAgent::Clear() { | |||
| if (notify_worker_) { | |||
| if (exit_notify_worker_) { | |||
| notify_worker_->Unregister(); | |||
| } | |||
| notify_worker_ = nullptr; | |||
| } | |||
| grpc_server_.Stop(); | |||
| executor_.UnloadModel(); | |||
| return SUCCESS; | |||
| } | |||
// Execute one distributed predict request on the local model slice.
// Currently a stub: returns SUCCESS without running inference.
Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) {
  // todo : DistributedPredictRequest->RequestBase
  // todo : DistributedPredictReply->ReplyBase
  return SUCCESS;
}
// Bring the agent up: load the model slice, start the agent gRPC server, and
// register with the distributed worker. Returns the first failing step's
// status; the config is kept in config_ for later use.
// NOTE(review): on RegisterAgent failure the already-started gRPC server and
// loaded model are not torn down here -- confirm the caller invokes Clear().
Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) {
  Status status;
  config_ = config;
  // Load this rank's model slice onto the configured device.
  status = executor_.LoadModelFromFile(config);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name
                  << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
                  << ", model file: " << config.model_file_name
                  << ", rank table file: " << config.rank_table_json_file_name
                  << ", group config file: " << config.group_file_name;
    return status;
  }
  // Serve predict/exit RPCs from the worker on agent_ip:agent_port.
  status = StartGrpcServer();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip
                  << ", agent port: " << config.agent_port;
    return status;
  }
  // Announce this agent (address + tensor infos) to the distributed worker.
  status = RegisterAgent();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port
                  << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port;
    return status;
  }
  MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name
               << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
               << ", model file: " << config.model_file_name
               << ", rank table file: " << config.rank_table_json_file_name
               << ", group config file: " << config.group_file_name;
  return SUCCESS;
}
// Start the agent-side gRPC service on config_.agent_ip:agent_port.
// NOTE(review): the result of grpc_server_.Start is not checked and SUCCESS
// is always returned -- confirm Start reports failure some other way.
Status WorkerAgent::StartGrpcServer() {
  grpc_server_.Start(std::make_shared<MSAgentImpl>(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent");
  return SUCCESS;
}
| Status WorkerAgent::RegisterAgent() { | |||
| notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.agent_port, config_.agent_ip, | |||
| config_.agent_port); | |||
| WorkerAgentSpec spec; | |||
| spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); | |||
| spec.rank_id = config_.rank_id; | |||
| spec.batch_size = executor_.GetBatchSize(); | |||
| spec.input_infos = executor_.GetInputInfos(); | |||
| spec.output_infos = executor_.GetOutputInfos(); | |||
| return notify_worker_->Register({spec}); | |||
| } | |||
// Request agent shutdown. `notify_worker` records (in exit_notify_worker_)
// whether Clear() should send an Unregister RPC to the worker; the exit
// handle's Stop() releases whatever is waiting on process exit.
void WorkerAgent::StopAgent(bool notify_worker) {
  exit_notify_worker_ = notify_worker;
  ExitSignalHandle::Instance().Stop();
}
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_AGENT_H | |||
| #define MINDSPORE_SERVING_WORKER_AGENT_H | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "worker/distributed_worker/agent_executor.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "proto/ms_agent.grpc.pb.h" | |||
| #include "common/grpc_server.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
// Singleton agent endpoint of a distributed servable: hosts one model slice,
// serves predict sub-requests over gRPC, and registers itself with the
// distributed worker.
class MS_API WorkerAgent {
 public:
  // Global singleton accessor.
  static WorkerAgent &Instance();
  // Tear down agent state (implementation not shown in this header).
  Status Clear();
  // Execute one distributed predict sub-request against the local model slice.
  Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply);
  // Start serving with the given start-up configuration (kept in config_).
  Status StartAgent(const AgentStartUpConfig &config);
  // Request shutdown; when `notify_worker` is true the distributed worker is
  // informed during cleanup (flag stored in exit_notify_worker_).
  void StopAgent(bool notify_worker = true);

 private:
  AgentStartUpConfig config_;       // start-up parameters (ips, ports, rank id, ...)
  WorkerAgentExecutor executor_;    // executes the local model slice
  GrpcServer grpc_server_;          // gRPC server exposing the agent service
  bool exit_notify_worker_ = true;  // whether the worker is notified on exit
  std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_;  // channel to the distributed worker
  // Start the agent-side gRPC service.
  Status StartGrpcServer();
  // Report this agent's spec to the distributed worker.
  Status RegisterAgent();
};
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_AGENT_H | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "master/dispacther.h" | |||
| #include "worker/worker.h" | |||
| namespace mindspore { | |||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSWorkerImpl final : public proto::MSWorker::Service { | |||
| class MSWorkerImpl : public proto::MSWorker::Service { | |||
| public: | |||
| grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | |||
| proto::PredictReply *reply) override; | |||
| @@ -21,12 +21,20 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| MSWorkerServer::~MSWorkerServer() { Stop(); } | |||
| MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { | |||
| Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { | |||
| if (in_running_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||
| } | |||
| service_impl_ = std::make_unique<MSWorkerImpl>(); | |||
| async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | |||
| return Init(); | |||
| } | |||
| MSWorkerServer::MSWorkerServer() = default; | |||
| Status MSWorkerServer::Init() { | |||
| Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); | |||
| if (status != SUCCESS) return status; | |||
| @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { | |||
| return status; | |||
| } | |||
| Status MSWorkerServer::Stop() { | |||
| if (in_running_) { | |||
| if (in_running_ && async_server_) { | |||
| async_server_->Stop(); | |||
| grpc_thread_.join(); | |||
| if (grpc_thread_.joinable()) { | |||
| grpc_thread_.join(); | |||
| } | |||
| } | |||
| async_server_ = nullptr; | |||
| service_impl_ = nullptr; | |||
| in_running_ = false; | |||
| return SUCCESS; | |||
| } | |||
| @@ -27,40 +27,53 @@ | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSWorkerServer { | |||
| class MS_API MSWorkerServer { | |||
| public: | |||
| enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | |||
| MSWorkerServer(const std::string &hostname, int32_t port); | |||
| ~MSWorkerServer(); | |||
| Status Init(); | |||
| MSWorkerServer(); | |||
| virtual ~MSWorkerServer(); | |||
| virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); | |||
| Status Stop(); | |||
| Status StartAsyncRpcService(); | |||
| protected: | |||
| bool in_running_ = false; | |||
| std::thread grpc_thread_; | |||
| std::unique_ptr<MSWorkerImpl> service_impl_; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_; | |||
| std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr; | |||
| Status Init(); | |||
| Status StartAsyncRpcService(); | |||
| }; | |||
| class WorkerServiceContext { | |||
| public: | |||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||
| WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| virtual ~WorkerServiceContext() {} | |||
| bool JudgeFinish() { return state_ == STATE::FINISH; } | |||
| virtual void StartEnqueueRequest() = 0; | |||
| virtual void HandleRequest() = 0; | |||
| virtual bool JudgeFinish() = 0; | |||
| protected: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| public: | |||
| STATE state_; | |||
| }; | |||
| @@ -68,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerPredictContext() = default; | |||
| @@ -93,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext { | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_; | |||
| proto::PredictRequest request_; | |||
| proto::PredictReply response_; | |||
| @@ -109,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerExitContext() = default; | |||
| @@ -134,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext { | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_; | |||
| proto::ExitRequest request_; | |||
| proto::ExitReply response_; | |||
| @@ -174,7 +171,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { | |||
| return SUCCESS; | |||
| } | |||
| private: | |||
| protected: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService svc_; | |||
| }; | |||
| @@ -52,132 +52,6 @@ enum DeviceType { | |||
| kDeviceTypeCpu, | |||
| }; | |||
| class MS_API InferSession { | |||
| public: | |||
| InferSession() = default; | |||
| virtual ~InferSession() = default; | |||
| virtual Status InitEnv(DeviceType device_type, uint32_t device_id, | |||
| const std::map<std::string, std::string> &other_options) = 0; | |||
| virtual Status FinalizeEnv() = 0; | |||
| virtual Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name, | |||
| ModelType model_type, const std::vector<int> &without_batch_dim_inputs, | |||
| const std::map<std::string, std::string> &other_options, uint32_t *model_id) = 0; | |||
| virtual Status UnloadModel(uint32_t model_id) = 0; | |||
| // override this method to avoid request/reply data copy | |||
| virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) = 0; | |||
| virtual Status ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request, | |||
| std::vector<TensorBasePtr> *reply) { | |||
| VectorTensorPtrWrapRequest wrap_request(request); | |||
| VectorTensorPtrWrapReply wrap_reply(reply, []() { return std::make_shared<Tensor>(); }); | |||
| return ExecuteModel(model_id, wrap_request, &wrap_reply); | |||
| } | |||
| virtual std::vector<TensorInfo> GetInputInfos(uint32_t model_id) const = 0; | |||
| virtual std::vector<TensorInfo> GetOutputInfos(uint32_t model_id) const = 0; | |||
| virtual ssize_t GetBatchSize(uint32_t model_id) const = 0; | |||
| virtual bool CheckModelSupport(DeviceType device_type, ModelType model_type) const { return true; } | |||
| }; | |||
| struct InferSessionRegInfo { | |||
| std::shared_ptr<InferSession> session; | |||
| ModelType model_type; | |||
| int priority; | |||
| }; | |||
| class MS_API InferSessionStorage { | |||
| public: | |||
| void Register(DeviceType device_type, ModelType model_type, const std::shared_ptr<InferSession> &session, | |||
| int priority) { | |||
| auto &list = session_map_[device_type]; | |||
| InferSessionRegInfo info{session, model_type, priority}; | |||
| list.push_back(info); | |||
| } | |||
| std::shared_ptr<InferSession> Get(DeviceType device_type, ModelType model_type, DeviceType *specified_device_type) { | |||
| MSI_EXCEPTION_IF_NULL(specified_device_type); | |||
| if (device_type == kDeviceTypeNotSpecified) { | |||
| for (auto &item_device : session_map_) { | |||
| std::shared_ptr<InferSession> ret_session = GetSession(item_device.second, item_device.first, model_type); | |||
| if (ret_session) { | |||
| *specified_device_type = item_device.first; | |||
| return ret_session; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } else if (device_type == kDeviceTypeAscend) { | |||
| auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; | |||
| for (auto ascend_type : ascend_list) { | |||
| auto it = session_map_.find(ascend_type); | |||
| if (it == session_map_.end()) { | |||
| continue; | |||
| } | |||
| auto session_ret = GetSession(it->second, ascend_type, model_type); | |||
| if (session_ret != nullptr) { | |||
| *specified_device_type = ascend_type; | |||
| return session_ret; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| auto it = session_map_.find(device_type); | |||
| if (it == session_map_.end()) { | |||
| return nullptr; | |||
| } | |||
| std::shared_ptr<InferSession> session_ret; | |||
| session_ret = GetSession(it->second, device_type, model_type); | |||
| *specified_device_type = device_type; | |||
| return session_ret; | |||
| } | |||
| static InferSessionStorage &Instance() { | |||
| static InferSessionStorage instance; | |||
| return instance; | |||
| } | |||
| private: | |||
| std::unordered_map<DeviceType, std::vector<InferSessionRegInfo>> session_map_; | |||
| std::shared_ptr<InferSession> GetSession(const std::vector<InferSessionRegInfo> &session_list, DeviceType device_type, | |||
| ModelType model_type) { | |||
| std::shared_ptr<InferSession> session_ret = nullptr; | |||
| int cur_priority = INT32_MIN; | |||
| for (auto &item : session_list) { | |||
| if (item.model_type != model_type) { | |||
| continue; | |||
| } | |||
| if (session_ret == nullptr || cur_priority < item.priority) { | |||
| if (!item.session->CheckModelSupport(device_type, model_type)) { | |||
| MSI_LOG_INFO << "CheckModelSupport for " << device_type << " " << model_type << " failed, skipped"; | |||
| continue; | |||
| } | |||
| cur_priority = item.priority; | |||
| session_ret = item.session; | |||
| } | |||
| } | |||
| return session_ret; | |||
| } | |||
| }; | |||
| class MS_API InferSessionRegister { | |||
| public: | |||
| InferSessionRegister(DeviceType device_type, ModelType model_type, const std::shared_ptr<InferSession> &session, | |||
| int priority) { | |||
| InferSessionStorage::Instance().Register(device_type, model_type, session, priority); | |||
| } | |||
| }; | |||
| #define REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) \ | |||
| static mindspore::serving::InferSessionRegister g_register_session_##cls_name##_##index( \ | |||
| device_type, model_type, std::make_shared<cls_name>(), priority); | |||
| #define REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, index) \ | |||
| REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) | |||
| #define REGISTER_INFER_SEESION(device_type, model_type, cls_name, priority) \ | |||
| REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, __COUNTER__); | |||
| static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) { | |||
| switch (device_type) { | |||
| case kDeviceTypeAscend: | |||
| @@ -26,16 +26,6 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| Status MindSporeModelWrap::InitEnv(serving::DeviceType device_type, uint32_t device_id, | |||
| const std::map<std::string, std::string> &other_options) { | |||
| return SUCCESS; | |||
| } | |||
| Status MindSporeModelWrap::FinalizeEnv() { | |||
| model_map_.clear(); | |||
| return SUCCESS; | |||
| } | |||
| mindspore::DataType TransInferDataType2ApiTypeId(DataType data_type) { | |||
| const std::map<DataType, mindspore::DataType> type2id_map{ | |||
| {serving::kMSI_Unknown, mindspore::DataType::kTypeUnknown}, | |||
| @@ -81,11 +71,9 @@ DataType TransTypeId2InferDataType(mindspore::DataType type_id) { | |||
| } | |||
| Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, | |||
| const std::string &file_name, ModelType model_type, | |||
| const std::string &file_name, ModelType model_type, bool with_batch_dim, | |||
| const std::vector<int> &without_batch_dim_inputs, | |||
| const std::map<std::string, std::string> &other_options, | |||
| uint32_t *model_id) { | |||
| MSI_EXCEPTION_IF_NULL(model_id); | |||
| const std::map<std::string, std::string> &other_options) { | |||
| std::string device_type_str; | |||
| if (device_type == kDeviceTypeAscendMS) { | |||
| device_type_str = mindspore::kDeviceTypeAscend910; | |||
| @@ -113,18 +101,18 @@ Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, ui | |||
| << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; | |||
| return Status(FAILED, status.ToString()); | |||
| } | |||
| model_index_++; | |||
| *model_id = model_index_; | |||
| ApiModelInfo api_model_info; | |||
| api_model_info.model = model; | |||
| api_model_info.device_type = device_type_str; | |||
| api_model_info.device_id = device_id; | |||
| api_model_info.with_batch_dim = with_batch_dim; | |||
| api_model_info.without_batch_dim_inputs = without_batch_dim_inputs; | |||
| auto st = GetModelInfos(&api_model_info); | |||
| if (st != SUCCESS) { | |||
| return st; | |||
| } | |||
| model_map_[*model_id] = api_model_info; | |||
| GetModelBatchSize(&api_model_info); | |||
| model_ = api_model_info; | |||
| MSI_LOG_INFO << "Load model from file success, model file: " << file_name << ", device_type: '" << device_type_str | |||
| << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; | |||
| return SUCCESS; | |||
| @@ -169,20 +157,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { | |||
| MSI_EXCEPTION_IF_NULL(api_model_info); | |||
| auto model = api_model_info->model; | |||
| bool first_dim_same = true; | |||
| auto find_batch_size = [&first_dim_same, api_model_info](const std::vector<int64_t> &shape) { | |||
| if (first_dim_same) { | |||
| if (shape.empty()) { | |||
| first_dim_same = false; | |||
| } else if (api_model_info->batch_size != 0) { | |||
| if (api_model_info->batch_size != shape[0]) { | |||
| first_dim_same = false; | |||
| } | |||
| } else { | |||
| api_model_info->batch_size = shape[0]; | |||
| } | |||
| } | |||
| }; | |||
| auto get_tensor_info_from_tensor = [](const mindspore::MSTensor &ms_tensor) { | |||
| serving::TensorInfo tensor_info; | |||
| tensor_info.shape = ms_tensor.Shape(); | |||
| @@ -204,10 +178,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Unknown input mindspore data type " << static_cast<int>(info.DataType()); | |||
| } | |||
| const auto &list = api_model_info->without_batch_dim_inputs; | |||
| if (std::find(list.begin(), list.end(), i) == list.end()) { | |||
| find_batch_size(tensor_info.shape); | |||
| } | |||
| api_model_info->input_tensor_infos.push_back(tensor_info); | |||
| api_model_info->input_names.push_back(info.Name()); | |||
| } | |||
| @@ -220,27 +190,59 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Unknown output mindspore data type " << static_cast<int>(info.DataType()); | |||
| } | |||
| find_batch_size(tensor_info.shape); | |||
| api_model_info->output_tensor_infos.push_back(tensor_info); | |||
| api_model_info->output_names.push_back(info.Name()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void MindSporeModelWrap::GetModelBatchSize(ApiModelInfo *api_model_info) { | |||
| MSI_EXCEPTION_IF_NULL(api_model_info); | |||
| bool first_dim_same = true; | |||
| auto find_batch_size = [&first_dim_same, api_model_info](const std::vector<int64_t> &shape) { | |||
| if (!api_model_info->with_batch_dim) { | |||
| first_dim_same = false; | |||
| return; | |||
| } | |||
| if (!first_dim_same) { | |||
| return; | |||
| } | |||
| if (shape.empty()) { | |||
| first_dim_same = false; | |||
| return; | |||
| } | |||
| if (api_model_info->batch_size != 0) { | |||
| if (api_model_info->batch_size != shape[0]) { | |||
| first_dim_same = false; | |||
| } | |||
| } else { | |||
| api_model_info->batch_size = shape[0]; | |||
| } | |||
| }; | |||
| auto list = api_model_info->without_batch_dim_inputs; | |||
| auto size = api_model_info->input_tensor_infos.size(); | |||
| for (size_t i = 0; i < size; i++) { | |||
| if (std::find(list.begin(), list.end(), i) == list.end()) { | |||
| auto &info = api_model_info->input_tensor_infos[i]; | |||
| find_batch_size(info.shape); | |||
| } | |||
| } | |||
| for (auto &info : api_model_info->output_tensor_infos) { | |||
| find_batch_size(info.shape); | |||
| } | |||
| if (!first_dim_same) { | |||
| api_model_info->batch_size = 0; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MindSporeModelWrap::UnloadModel(uint32_t model_id) { | |||
| auto it = model_map_.find(model_id); | |||
| if (it == model_map_.end()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id; | |||
| } | |||
| model_map_.erase(it); | |||
| Status MindSporeModelWrap::UnloadModel() { | |||
| model_.model = nullptr; | |||
| return SUCCESS; | |||
| } | |||
| Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &request, serving::ReplyBase *reply) { | |||
| Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::ReplyBase *reply) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { | |||
| auto input_tensor = request[index]; | |||
| @@ -260,11 +262,10 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &re | |||
| tensor->set_data_type(data_type); | |||
| tensor->set_shape(shape); | |||
| }; | |||
| return ExecuteModelCommon(model_id, request.size(), func_in, func_out); | |||
| return ExecuteModelCommon(request.size(), func_in, func_out); | |||
| } | |||
| Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request, | |||
| std::vector<TensorBasePtr> *reply) { | |||
| Status MindSporeModelWrap::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { | |||
| auto &input_tensor = request[index]; | |||
| @@ -282,16 +283,15 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector<Ten | |||
| tensor->set_shape(shape); | |||
| reply->push_back(tensor); | |||
| }; | |||
| return ExecuteModelCommon(model_id, request.size(), func_in, func_out); | |||
| return ExecuteModelCommon(request.size(), func_in, func_out); | |||
| } | |||
| Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, | |||
| Status MindSporeModelWrap::ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, | |||
| const FuncMakeOutTensor &out_func) { | |||
| auto it = model_map_.find(model_id); | |||
| if (it == model_map_.end()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id; | |||
| if (model_.model == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded"; | |||
| } | |||
| auto &model_info = it->second; | |||
| auto &model_info = model_; | |||
| auto model = model_info.model; | |||
| auto &input_names = model_info.input_names; | |||
| auto &output_names = model_info.output_names; | |||
| @@ -327,43 +327,25 @@ Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_ | |||
| return SUCCESS; | |||
| } | |||
| std::vector<serving::TensorInfo> MindSporeModelWrap::GetInputInfos(uint32_t model_id) const { | |||
| auto it = model_map_.find(model_id); | |||
| if (it == model_map_.end()) { | |||
| MSI_LOG_ERROR << "Invalid model id " << model_id; | |||
| return {}; | |||
| } | |||
| auto &model_info = it->second; | |||
| return model_info.input_tensor_infos; | |||
| } | |||
| std::vector<serving::TensorInfo> MindSporeModelWrap::GetInputInfos() const { return model_.input_tensor_infos; } | |||
| std::vector<serving::TensorInfo> MindSporeModelWrap::GetOutputInfos(uint32_t model_id) const { | |||
| auto it = model_map_.find(model_id); | |||
| if (it == model_map_.end()) { | |||
| MSI_LOG_ERROR << "Invalid model id " << model_id; | |||
| return {}; | |||
| } | |||
| auto &model_info = it->second; | |||
| return model_info.output_tensor_infos; | |||
| } | |||
| std::vector<serving::TensorInfo> MindSporeModelWrap::GetOutputInfos() const { return model_.output_tensor_infos; } | |||
| ssize_t MindSporeModelWrap::GetBatchSize(uint32_t model_id) const { | |||
| auto it = model_map_.find(model_id); | |||
| if (it == model_map_.end()) { | |||
| MSI_LOG_ERROR << "Invalid model id " << model_id; | |||
| return {}; | |||
| } | |||
| auto &model_info = it->second; | |||
| return model_info.batch_size; | |||
| } | |||
| ssize_t MindSporeModelWrap::GetBatchSize() const { return model_.batch_size; } | |||
| bool MindSporeModelWrap::CheckModelSupport(DeviceType device_type, ModelType model_type) const { | |||
| std::string device_type_str; | |||
| switch (device_type) { | |||
| case kDeviceTypeAscendMS: | |||
| if (model_type != kMindIR) { | |||
| return false; | |||
| } | |||
| device_type_str = mindspore::kDeviceTypeAscend910; | |||
| break; | |||
| case kDeviceTypeAscendCL: | |||
| if (model_type != kMindIR && model_type != kOM) { | |||
| return false; | |||
| } | |||
| device_type_str = mindspore::kDeviceTypeAscend310; | |||
| break; | |||
| default: | |||
| @@ -378,9 +360,5 @@ ApiBufferTensorWrap::ApiBufferTensorWrap(const mindspore::MSTensor &tensor) : te | |||
| ApiBufferTensorWrap::~ApiBufferTensorWrap() = default; | |||
| REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kOM, MindSporeModelWrap, 1); | |||
| REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kMindIR, MindSporeModelWrap, 1); | |||
| REGISTER_INFER_SEESION(serving::kDeviceTypeAscendMS, kMindIR, MindSporeModelWrap, 1); | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -34,54 +34,46 @@ struct ApiModelInfo { | |||
| std::vector<serving::TensorInfo> input_tensor_infos; | |||
| std::vector<std::string> output_names; | |||
| std::vector<serving::TensorInfo> output_tensor_infos; | |||
| std::shared_ptr<mindspore::Model> model; | |||
| std::shared_ptr<mindspore::Model> model = nullptr; | |||
| uint32_t batch_size = 0; | |||
| std::string device_type; | |||
| uint32_t device_id = 0; | |||
| bool with_batch_dim = false; | |||
| std::vector<int> without_batch_dim_inputs; | |||
| }; | |||
| class MindSporeModelWrap : public InferSession { | |||
| class MindSporeModelWrap { | |||
| public: | |||
| MindSporeModelWrap() = default; | |||
| ~MindSporeModelWrap() = default; | |||
| Status InitEnv(serving::DeviceType device_type, uint32_t device_id, | |||
| const std::map<std::string, std::string> &other_options) override; | |||
| Status FinalizeEnv() override; | |||
| Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name, | |||
| ModelType model_type, const std::vector<int> &without_batch_dim_inputs, | |||
| const std::map<std::string, std::string> &other_options, uint32_t *model_id) override; | |||
| Status UnloadModel(uint32_t model_id) override; | |||
| ModelType model_type, bool with_batch_dim, const std::vector<int> &without_batch_dim_inputs, | |||
| const std::map<std::string, std::string> &other_options); | |||
| // override this method to avoid request/reply data copy | |||
| Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) override; | |||
| Status ExecuteModel(uint32_t model_id, const std::vector<TensorBasePtr> &request, | |||
| std::vector<TensorBasePtr> *reply) override; | |||
| Status UnloadModel(); | |||
| Status ExecuteModel(const RequestBase &request, ReplyBase *reply); | |||
| Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply); | |||
| std::vector<serving::TensorInfo> GetInputInfos(uint32_t model_id) const override; | |||
| std::vector<serving::TensorInfo> GetInputInfos() const; | |||
| std::vector<serving::TensorInfo> GetOutputInfos(uint32_t model_id) const override; | |||
| std::vector<serving::TensorInfo> GetOutputInfos() const; | |||
| ssize_t GetBatchSize(uint32_t model_id) const override; | |||
| ssize_t GetBatchSize() const; | |||
| bool CheckModelSupport(DeviceType device_type, ModelType model_type) const override; | |||
| bool CheckModelSupport(DeviceType device_type, ModelType model_type) const; | |||
| private: | |||
| std::unordered_map<uint32_t, ApiModelInfo> model_map_; | |||
| uint32_t model_index_ = 0; | |||
| ApiModelInfo model_; | |||
| using FuncMakeInBuffer = std::function<mindspore::MSTensor(size_t index, const std::string &name)>; | |||
| using FuncMakeOutTensor = | |||
| std::function<void(const mindspore::MSTensor, DataType data_type, const std::vector<int64_t> &shape)>; | |||
| Status ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, | |||
| const FuncMakeOutTensor &out_func); | |||
| Status ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func); | |||
| Status GetModelInfos(ApiModelInfo *model_info); | |||
| std::shared_ptr<Context> TransformModelContext(const std::map<std::string, std::string> &other_options); | |||
| void GetModelBatchSize(ApiModelInfo *model_info); | |||
| }; | |||
| class ApiBufferTensorWrap : public TensorBase { | |||
| @@ -0,0 +1,254 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/local_servable/local_sevable.h" | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "common/tensor.h" | |||
| #include "common/file_system_operation.h" | |||
| #include "worker/context.h" | |||
namespace {
// Version-selection strategies: "latest" serves the newest available version
// directory; "specific" serves exactly the caller-specified version number.
static const char *kVersionStrategyLatest = "latest";
static const char *kVersionStrategySpecific = "specific";
}  // namespace
| namespace mindspore::serving { | |||
// Release servable resources on destruction.
LocalModelServable::~LocalModelServable() { Clear(); }

// Name of the servable currently being served.
std::string LocalModelServable::GetServableName() const { return servable_name_; }

// Resolved (real) version number of the loaded servable.
uint64_t LocalModelServable::GetServableVersion() const { return version_number_; }
// Execute one inference request through the loaded model session.
// Raises (via MSI_LOG_EXCEPTION) when no model has been loaded yet.
Status LocalModelServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return session_.ExecuteModel(input, output);
}
// Input tensor signatures of the loaded model.
// Raises (via MSI_LOG_EXCEPTION) when no model has been loaded yet.
std::vector<TensorInfo> LocalModelServable::GetInputInfos() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return session_.GetInputInfos();
}
// Output tensor signatures of the loaded model.
// Raises (via MSI_LOG_EXCEPTION) when no model has been loaded yet.
std::vector<TensorInfo> LocalModelServable::GetOutputInfos() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return session_.GetOutputInfos();
}
// Batch size of the loaded model (as reported by the inference session).
// Raises (via MSI_LOG_EXCEPTION) when no model has been loaded yet.
uint64_t LocalModelServable::GetBatchSize() const {
  if (!model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has not been loaded";
  }
  return session_.GetBatchSize();
}
// Start serving a model: resolve the requested version under
// <servable_directory>/<servable_name>, initialize the device, load the model
// and mark the servable as loaded. version_number == 0 means "latest".
Status LocalModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name,
                                         uint64_t version_number) {
  // Double-start is a programming error, not a recoverable failure.
  if (model_loaded_) {
    MSI_LOG_EXCEPTION << "Model has loaded";
  }
  base_spec_.servable_directory = servable_directory;
  base_spec_.servable_name = servable_name;
  base_spec_.version_number = version_number;
  // 0 selects the newest version directory; anything else pins that version.
  std::string version_strategy;
  if (version_number == 0) {
    version_strategy = kVersionStrategyLatest;
  } else {
    version_strategy = kVersionStrategySpecific;
  }
  Status status;
  ServableSignature signature;
  // The servable must have been registered (python-side definition) first.
  if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) {
    return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered";
  }
  status = InitDevice(signature.servable_meta.local_meta.model_format, {});
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Init env failed";
    return status;
  }
  std::vector<uint64_t> real_versions;
  status = LoadServableConfig(base_spec_, version_strategy, &real_versions);
  if (status != SUCCESS) {
    return INFER_STATUS_LOG_ERROR(FAILED)
           << "Start servable failed, there is no servable of the specified version number, specified version number: "
           << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '"
           << base_spec_.servable_name
           << "'. version number is a positive integer(started from 1) and 0 represents the maximum version number.";
  }
  // NOTE(review): assumes LoadServableConfig only returns SUCCESS with a
  // non-empty real_versions — otherwise real_versions[0] is UB; confirm.
  auto real_version_number = real_versions[0];
  status = LoadModel(real_version_number);
  if (status != SUCCESS) {
    return status;
  }
  // Publish the resolved identity only after a successful load.
  servable_name_ = base_spec_.servable_name;
  version_number_ = real_version_number;
  model_loaded_ = true;
  // Echo the load-model status message to both the log and stdout.
  MSI_LOG_INFO << status.StatusMessage();
  std::cout << status.StatusMessage() << std::endl;
  return SUCCESS;
}
| void LocalModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions) { | |||
| MSI_EXCEPTION_IF_NULL(real_versions); | |||
| // define version_strategy:"specific","latest","multi" | |||
| if (version_strategy_ == kVersionStrategySpecific) { | |||
| real_versions->push_back(servable_spec.version_number); | |||
| return; | |||
| } | |||
| auto trans_to_integer = [](const std::string &str) -> uint32_t { | |||
| uint32_t parsed_value = 0; | |||
| for (auto c : str) { | |||
| if (c < '0' || c > '9') { | |||
| return 0; | |||
| } | |||
| parsed_value = parsed_value * 10 + c - '0'; | |||
| } | |||
| if (std::to_string(parsed_value) != str) { | |||
| return 0; | |||
| } | |||
| return parsed_value; | |||
| }; | |||
| uint64_t newest_version = 0; | |||
| std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name; | |||
| auto sub_dir = GetAllSubDirsNotFullPath(model_path); | |||
| static std::set<std::string> ignore_dir; | |||
| for (const auto &dir : sub_dir) { | |||
| if (dir == "__pycache__") continue; | |||
| auto version_parse = trans_to_integer(dir); | |||
| if (version_parse == 0) { | |||
| if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { | |||
| MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory | |||
| << ", model_name " << servable_spec.servable_name; | |||
| } | |||
| continue; | |||
| } | |||
| real_versions->push_back(version_parse); | |||
| if (version_parse > newest_version) { | |||
| newest_version = version_parse; | |||
| } | |||
| } | |||
| if (version_strategy_ == kVersionStrategyLatest) { | |||
| real_versions->clear(); | |||
| if (newest_version != 0) { | |||
| real_versions->push_back(newest_version); | |||
| } | |||
| } | |||
| } | |||
| Status LocalModelServable::LoadServableConfig(const LoadServableSpec &servable_spec, | |||
| const std::string &version_strategy, | |||
| std::vector<uint64_t> *real_versions) { | |||
| MSI_EXCEPTION_IF_NULL(real_versions); | |||
| auto model_directory = servable_spec.servable_directory; | |||
| auto model_name = servable_spec.servable_name; | |||
| if (!DirOrFileExist(model_directory + "/" + model_name)) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model not found, model_directory " << model_directory << ", model_name " << model_name; | |||
| } | |||
| std::string model_path = model_directory + "/" + model_name; | |||
| auto version_directory = [model_path](int64_t version_number) { | |||
| return model_path + "/" + std::to_string(version_number); | |||
| }; | |||
| version_strategy_ = version_strategy; | |||
| // version_strategy:"specific","latest","multi" | |||
| GetVersions(servable_spec, real_versions); | |||
| if (real_versions->size() == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name; | |||
| } | |||
| for (auto real_version_number : *real_versions) { | |||
| if (!DirOrFileExist(version_directory(real_version_number))) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory " | |||
| << model_directory << ", model_name " << model_name; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status LocalModelServable::InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options) { | |||
| Status status; | |||
| auto context = ServableContext::Instance(); | |||
| DeviceType device_type = ServableContext::Instance()->GetDeviceType(); | |||
| auto get_support_device_type = [this, device_type, model_type]() { | |||
| std::vector<DeviceType> support_device_list; | |||
| if (device_type == kDeviceTypeNotSpecified || device_type == kDeviceTypeAscend) { | |||
| auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; | |||
| for (auto item : ascend_list) { | |||
| if (session_.CheckModelSupport(item, model_type)) { | |||
| return item; | |||
| } | |||
| } | |||
| } else if (device_type == kDeviceTypeAscendCL || device_type == kDeviceTypeAscendMS) { | |||
| if (session_.CheckModelSupport(device_type, model_type)) { | |||
| return device_type; | |||
| } | |||
| } | |||
| return kDeviceTypeNotSpecified; | |||
| }; | |||
| auto support_device_type = get_support_device_type(); | |||
| if (support_device_type == kDeviceTypeNotSpecified) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Not support device type " << device_type << " and model type " << model_type | |||
| << ". Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; | |||
| } | |||
| context->SetDeviceType(support_device_type); | |||
| return SUCCESS; | |||
| } | |||
| Status LocalModelServable::LoadModel(uint64_t version_number) { | |||
| ServableSignature signature; | |||
| if (!ServableStorage::Instance().GetServableDef(base_spec_.servable_name, &signature)) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << base_spec_.servable_name << " has not been registered"; | |||
| } | |||
| const auto &servable_meta = signature.servable_meta; | |||
| const auto &common_meta = servable_meta.common_meta; | |||
| const auto &local_meta = servable_meta.local_meta; | |||
| std::string model_file_name = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" + | |||
| std::to_string(version_number) + "/" + local_meta.servable_file; | |||
| auto context = ServableContext::Instance(); | |||
| Status status = session_.LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, | |||
| local_meta.model_format, common_meta.with_batch_dim, | |||
| common_meta.without_batch_dim_inputs, local_meta.load_options); | |||
| if (status != SUCCESS) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '" | |||
| << base_spec_.servable_name << "', servable file: '" << local_meta.servable_file << "', version number " | |||
| << version_number << ", options " << local_meta.load_options; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void LocalModelServable::Clear() { | |||
| if (model_loaded_) { | |||
| session_.UnloadModel(); | |||
| } | |||
| model_loaded_ = false; | |||
| } | |||
| } // namespace mindspore::serving | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H | |||
| #define MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <map> | |||
| #include "common/serving_common.h" | |||
| #include "common/instance.h" | |||
| #include "common/servable.h" | |||
| #include "worker/sevable_base.h" | |||
| #include "worker/inference/inference.h" | |||
| #include "worker/inference/mindspore_model_wrap.h" | |||
| namespace mindspore::serving { | |||
// Servable backed by a model stored on the local file system
// (<servable_directory>/<servable_name>/<version_number>/<servable_file>)
// and executed through an in-process MindSporeModelWrap session.
class MS_API LocalModelServable : public ServableBase {
 public:
  LocalModelServable() = default;
  ~LocalModelServable() override;

  // Run inference on the loaded model: 'input' tensors in, '*output' tensors out.
  Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override;
  std::vector<TensorInfo> GetInputInfos() const override;
  std::vector<TensorInfo> GetOutputInfos() const override;
  uint64_t GetBatchSize() const override;
  // Resolve and load the servable; version_number 0 selects the newest version on disk.
  Status StartServable(const std::string &servable_directory, const std::string &servable_name,
                       uint64_t version_number);
  // Select and initialize a device backend capable of running 'model_type'.
  Status InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options);
  std::string GetServableName() const override;
  uint64_t GetServableVersion() const override;
  // Unload the model if one is loaded.
  void Clear() override;

 private:
  LoadServableSpec base_spec_;   // directory/name/version requested at start
  std::string servable_name_;    // name recorded once StartServable succeeds
  uint64_t version_number_ = 0;  // resolved (real) version number loaded
  MindSporeModelWrap session_;   // inference session owning the model
  std::string version_strategy_;  // "specific", "latest" or "multi"
  bool model_loaded_ = false;     // true after StartServable succeeds, cleared by Clear()

  // Scan the servable directory for version sub-directories per version_strategy_.
  void GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions);
  // Resolve versions to load and validate the on-disk layout.
  Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy,
                            std::vector<uint64_t> *real_version_number);
  // Load one specific version into session_.
  Status LoadModel(uint64_t version);
};
| } // namespace mindspore::serving | |||
| #endif // MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H | |||
| @@ -1,33 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/model.h" | |||
| #include <algorithm> | |||
| #include "mindspore_serving/ccsrc/common/tensor.h" | |||
| namespace mindspore::serving { | |||
| Status AscendModelServable::Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) { | |||
| return session_->ExecuteModel(model_id_, input, output); | |||
| } | |||
| std::vector<TensorInfo> AscendModelServable::GetInputInfos() const { return session_->GetInputInfos(model_id_); } | |||
| std::vector<TensorInfo> AscendModelServable::GetOutputInfos() const { return session_->GetOutputInfos(model_id_); } | |||
| uint64_t AscendModelServable::GetBatchSize() const { return session_->GetBatchSize(model_id_); } | |||
| } // namespace mindspore::serving | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_MODEL_H | |||
| #define MINDSPORE_SERVING_WORKER_MODEL_H | |||
| #ifndef MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H | |||
| #define MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| @@ -39,25 +39,11 @@ class ServableBase { | |||
| virtual std::vector<TensorInfo> GetInputInfos() const = 0; | |||
| virtual std::vector<TensorInfo> GetOutputInfos() const = 0; | |||
| virtual uint64_t GetBatchSize() const = 0; | |||
| }; | |||
| class AscendModelServable : public ServableBase { | |||
| public: | |||
| AscendModelServable(const std::shared_ptr<serving::InferSession> &session, uint32_t model_id) | |||
| : session_(session), model_id_(model_id) {} | |||
| ~AscendModelServable() = default; | |||
| Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override; | |||
| std::vector<TensorInfo> GetInputInfos() const override; | |||
| std::vector<TensorInfo> GetOutputInfos() const override; | |||
| uint64_t GetBatchSize() const override; | |||
| private: | |||
| std::shared_ptr<serving::InferSession> session_{nullptr}; | |||
| uint32_t model_id_ = 0; | |||
| virtual std::string GetServableName() const = 0; | |||
| virtual uint64_t GetServableVersion() const = 0; | |||
| virtual void Clear() = 0; | |||
| }; | |||
| } // namespace mindspore::serving | |||
| #endif // MINDSPORE_SERVING_WORKER_MODEL_H | |||
| #endif // MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H | |||
| @@ -49,15 +49,15 @@ Status WorkExecutor::CheckSevableSignature() { | |||
| if (servable_declare_.methods.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "There is no method registered for servable"; | |||
| } | |||
| if (input_infos.size() != servable_declare_.servable_meta.inputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "The inputs count " << servable_declare_.servable_meta.inputs_count << " registered in method " | |||
| << "not equal to the count " << input_infos.size() << " defined in servable"; | |||
| const auto &common_meta = servable_declare_.servable_meta.common_meta; | |||
| if (input_infos.size() != common_meta.inputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count << " registered in method " | |||
| << "not equal to the count " << input_infos.size() << " defined in servable"; | |||
| } | |||
| const auto &output_infos = output_infos_; | |||
| if (output_infos.size() != servable_declare_.servable_meta.outputs_count) { | |||
| if (output_infos.size() != common_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "The outputs count " << servable_declare_.servable_meta.outputs_count << " registered in method " | |||
| << "The outputs count " << common_meta.outputs_count << " registered in method " | |||
| << "not equal to the count " << output_infos.size() << " defined in servable"; | |||
| } | |||
| MSI_LOG_INFO << "Model input infos: count " << input_infos.size(); | |||
| @@ -68,7 +68,7 @@ Status WorkExecutor::CheckSevableSignature() { | |||
| for (auto &item : output_infos) { | |||
| MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; | |||
| } | |||
| if (servable_declare_.servable_meta.with_batch_dim) { | |||
| if (common_meta.with_batch_dim) { | |||
| if (model_batch_size_ == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Servable batch size cannot be " << model_batch_size_; | |||
| } | |||
| @@ -104,7 +104,7 @@ Status WorkExecutor::Init(const ServableSignature &servable_declare, const std:: | |||
| servable_ = servable; | |||
| input_infos_ = servable_->GetInputInfos(); | |||
| output_infos_ = servable_->GetOutputInfos(); | |||
| if (servable_declare_.servable_meta.with_batch_dim) { | |||
| if (servable_declare_.servable_meta.common_meta.with_batch_dim) { | |||
| model_batch_size_ = servable_->GetBatchSize(); | |||
| } else { | |||
| model_batch_size_ = 1; | |||
| @@ -389,7 +389,7 @@ Status WorkExecutor::PostPredict(const std::vector<Instance> &inputs, const std: | |||
| MSI_LOG_EXCEPTION << "Output result data size cannot be 0"; | |||
| } | |||
| auto shape = item->shape(); | |||
| if (servable_declare_.servable_meta.with_batch_dim) { | |||
| if (servable_declare_.servable_meta.common_meta.with_batch_dim) { | |||
| if (shape.empty() || shape[0] != model_batch_size) { | |||
| MSI_LOG_EXCEPTION << "Output shape " << shape << " not match batch size " << model_batch_size; | |||
| } | |||
| @@ -429,9 +429,9 @@ Status WorkExecutor::Predict(const std::vector<Instance> &inputs, std::vector<In | |||
| } | |||
| bool WorkExecutor::IsNoBatchDimInput(int input_index) const { | |||
| auto without_batch_dim_inputs = servable_declare_.servable_meta.without_batch_dim_inputs; | |||
| auto without_batch_dim_inputs = servable_declare_.servable_meta.common_meta.without_batch_dim_inputs; | |||
| bool no_batch_dim = true; | |||
| if (servable_declare_.servable_meta.with_batch_dim) { | |||
| if (servable_declare_.servable_meta.common_meta.with_batch_dim) { | |||
| no_batch_dim = std::find(without_batch_dim_inputs.begin(), without_batch_dim_inputs.end(), input_index) != | |||
| without_batch_dim_inputs.end(); | |||
| } | |||
| @@ -28,7 +28,7 @@ | |||
| #include "common/serving_common.h" | |||
| #include "common/instance.h" | |||
| #include "common/servable.h" | |||
| #include "worker/model.h" | |||
| #include "worker/sevable_base.h" | |||
| #include "worker/predict_thread.h" | |||
| #include "worker/task_queue.h" | |||
| @@ -39,10 +39,8 @@ using WorkCallBack = std::function<void(const Instance &output, const Status &er | |||
| class WorkExecutor { | |||
| public: | |||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> py_postprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> cpp_preprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> cpp_postprocess_task_queue); | |||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess, std::shared_ptr<TaskQueue> py_postprocess, | |||
| std::shared_ptr<TaskQueue> cpp_preprocess, std::shared_ptr<TaskQueue> cpp_postprocess); | |||
| ~WorkExecutor(); | |||
| Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable); | |||
| @@ -34,46 +34,16 @@ namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace serving { | |||
| static const char *kVersionStrategyLastest = "lastest"; | |||
| static const char *kVersionStrategySpecific = "specific"; | |||
| static std::unique_ptr<MSWorkerServer> grpc_async_worker_server_; | |||
| Worker &Worker::GetInstance() { | |||
| static Worker instance; | |||
| return instance; | |||
| } | |||
| Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { | |||
| if (grpc_async_worker_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; | |||
| } | |||
| grpc_async_worker_server_ = std::make_unique<MSWorkerServer>(ip, grpc_port); | |||
| return grpc_async_worker_server_->Init(); | |||
| } | |||
| Status Worker::RegisterWorker() { | |||
| std::vector<LoadServableSpec> specs; | |||
| std::vector<ServableSignature> signatures; | |||
| for (auto &work : work_list_) { | |||
| specs.push_back(work.servable_spec); | |||
| signatures.push_back(work.servable_signature); | |||
| } | |||
| std::vector<WorkerSpec> worker_specs; | |||
| for (size_t i = 0; i < specs.size(); i++) { | |||
| auto &spec = specs[i]; | |||
| auto &servable_signature = signatures[i]; | |||
| WorkerSpec worker_spec; | |||
| worker_spec.servable_name = spec.servable_name; | |||
| worker_spec.version_number = spec.version_number; | |||
| for (auto &method : servable_signature.methods) { | |||
| WorkerMethodInfo worker_method_info; | |||
| worker_method_info.name = method.method_name; | |||
| for (auto &name : method.inputs) { | |||
| worker_method_info.input_names.push_back(name); | |||
| } | |||
| worker_spec.methods.push_back(worker_method_info); | |||
| } | |||
| worker_specs.push_back(worker_spec); | |||
| for (auto &work : work_list_) { | |||
| // cppcheck-suppress useStlAlgorithm | |||
| worker_specs.push_back(work.worker_spec); | |||
| } | |||
| auto status = notify_master_->Register(worker_specs); | |||
| return status; | |||
| @@ -84,34 +54,10 @@ Status Worker::StartVersionController() { | |||
| return SUCCESS; | |||
| } | |||
| Status Worker::AddWorker(const ServableWorkerContext &work) { | |||
| WorkerSpec worker_spec; | |||
| worker_spec.servable_name = work.servable_spec.servable_name; | |||
| worker_spec.version_number = work.servable_spec.version_number; | |||
| for (auto &method : work.servable_signature.methods) { | |||
| WorkerMethodInfo worker_method_info; | |||
| worker_method_info.name = method.method_name; | |||
| for (auto &name : method.inputs) { | |||
| worker_method_info.input_names.push_back(name); | |||
| } | |||
| worker_spec.methods.push_back(worker_method_info); | |||
| } | |||
| return notify_master_->AddWorker(worker_spec); | |||
| } | |||
| Status Worker::AddWorker(const ServableWorkerContext &work) { return notify_master_->AddWorker(work.worker_spec); } | |||
| Status Worker::RemoveWorker(const ServableWorkerContext &work) { | |||
| WorkerSpec worker_spec; | |||
| worker_spec.servable_name = work.servable_spec.servable_name; | |||
| worker_spec.version_number = work.servable_spec.version_number; | |||
| for (auto &method : work.servable_signature.methods) { | |||
| WorkerMethodInfo worker_method_info; | |||
| worker_method_info.name = method.method_name; | |||
| for (auto &name : method.inputs) { | |||
| worker_method_info.input_names.push_back(name); | |||
| } | |||
| worker_spec.methods.push_back(worker_method_info); | |||
| } | |||
| return notify_master_->RemoveWorker(worker_spec); | |||
| return notify_master_->RemoveWorker(work.worker_spec); | |||
| } | |||
| Status Worker::Run(const proto::PredictRequest &request, proto::PredictReply *reply) { | |||
| @@ -189,74 +135,8 @@ std::pair<Status, std::shared_ptr<AsyncResult>> Worker::RunAsync(const RequestSp | |||
| return {SUCCESS, result}; | |||
| } | |||
| Status Worker::InitEnv(ModelType model_type, const std::map<std::string, std::string> &other_options) { | |||
| Status status; | |||
| if (session_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Session has been inited"; | |||
| } | |||
| auto context = ServableContext::Instance(); | |||
| DeviceType device_type = kDeviceTypeNotSpecified; | |||
| session_ = InferSessionStorage::Instance().Get(context->GetDeviceType(), model_type, &device_type); | |||
| if (session_ == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Cannot find session registered for device type " << context->GetDeviceType() << " and model type " | |||
| << model_type << ". Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; | |||
| } | |||
| if (device_type != kDeviceTypeNotSpecified) { | |||
| context->SetDeviceType(device_type); | |||
| } | |||
| status = session_->InitEnv(context->GetDeviceType(), context->GetDeviceId(), other_options); | |||
| if (status != SUCCESS) { | |||
| session_ = nullptr; | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Init env failed, device type " << context->GetDeviceType() << ", device id " << context->GetDeviceId(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status Worker::FinalizeEnv() { | |||
| if (session_ != nullptr) { | |||
| return session_->FinalizeEnv(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status Worker::LoadModel(LoadServableSpec *servable_spec, uint64_t version_number, ServableWorkerContext *work) { | |||
| MSI_EXCEPTION_IF_NULL(servable_spec); | |||
| MSI_EXCEPTION_IF_NULL(work); | |||
| servable_spec->version_number = version_number; | |||
| ServableSignature signature; | |||
| if (!ServableStorage::Instance().GetServableDef(servable_spec->servable_name, &signature)) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_spec->servable_name << " has not been registerd"; | |||
| } | |||
| const auto &servable_meta = signature.servable_meta; | |||
| std::string model_file_name = servable_spec->servable_directory + "/" + servable_spec->servable_name + "/" + | |||
| std::to_string(version_number) + "/" + servable_meta.servable_file; | |||
| uint32_t model_id; | |||
| auto context = ServableContext::Instance(); | |||
| Status status = session_->LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, | |||
| servable_meta.model_format, servable_meta.without_batch_dim_inputs, | |||
| servable_meta.load_options, &model_id); | |||
| if (status != SUCCESS) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Load model failed, servable directory: '" << servable_spec->servable_directory << "', servable name: '" | |||
| << servable_spec->servable_name << "', servable file: '" << servable_meta.servable_file | |||
| << "', version number " << version_number << ", options " << servable_meta.load_options; | |||
| } | |||
| auto service = std::make_shared<WorkExecutor>(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), | |||
| GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); | |||
| status = service->Init(signature, std::make_shared<AscendModelServable>(session_, model_id)); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| work->servable_spec = *servable_spec; | |||
| work->servable_signature = signature; | |||
| work->worker_service = service; | |||
| work->model_id = model_id; | |||
| work->model_file_name = model_file_name; | |||
| return SUCCESS; | |||
| } | |||
| void Worker::Update() { | |||
| /* | |||
| if (version_strategy_ == kVersionStrategySpecific) { | |||
| return; | |||
| } | |||
| @@ -291,10 +171,19 @@ void Worker::Update() { | |||
| MSI_LOG_INFO << "UnLoad Model version " << iter->servable_spec.version_number << " success"; | |||
| work_list_.erase(iter); | |||
| } | |||
| */ | |||
| } | |||
| Status Worker::StartServable(const std::string &servable_directory, const std::string &servable_name, | |||
| uint32_t version_number, std::shared_ptr<BaseNotifyMaster> notify_master) { | |||
| Status Worker::StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip, | |||
| int32_t port) { | |||
| if (worker_grpc_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; | |||
| } | |||
| worker_grpc_server_ = grpc_server; | |||
| return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port); | |||
| } | |||
| Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| if (servable_started_) { | |||
| MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently."; | |||
| @@ -307,58 +196,42 @@ Status Worker::StartServable(const std::string &servable_directory, const std::s | |||
| cpp_postprocess_.Start(2); | |||
| notify_master_ = std::move(notify_master); | |||
| base_spec_.servable_directory = servable_directory; | |||
| base_spec_.servable_name = servable_name; | |||
| base_spec_.version_number = version_number; | |||
| std::string version_strategy; | |||
| if (version_number == 0) { | |||
| version_strategy = kVersionStrategyLastest; | |||
| } else { | |||
| version_strategy = kVersionStrategySpecific; | |||
| } | |||
| Status status; | |||
| auto servable_name = servable->GetServableName(); | |||
| ServableSignature signature; | |||
| if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; | |||
| } | |||
| if (session_ == nullptr) { | |||
| status = InitEnv(signature.servable_meta.model_format, {}); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "Init env failed"; | |||
| return status; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_name << " has not been registered"; | |||
| } | |||
| std::vector<uint64_t> real_versions; | |||
| status = LoadServableConfig(base_spec_, version_strategy, &real_versions); | |||
| auto service = std::make_shared<WorkExecutor>(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), | |||
| GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); | |||
| auto status = service->Init(signature, servable); | |||
| if (status != SUCCESS) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Start servable failed, there is no servable of the specified version number, specified version number: " | |||
| << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" | |||
| << base_spec_.servable_name | |||
| << "'. version number is a positive integer(started from 1) and 0 represents the maximum version number."; | |||
| return status; | |||
| } | |||
| for (auto real_version_number : real_versions) { | |||
| ServableWorkerContext work; | |||
| status = LoadModel(&base_spec_, real_version_number, &work); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| ServableWorkerContext work; | |||
| WorkerSpec worker_spec; | |||
| worker_spec.servable_name = servable_name; | |||
| worker_spec.version_number = servable->GetServableVersion(); | |||
| for (auto &method : signature.methods) { | |||
| WorkerMethodInfo worker_method_info; | |||
| worker_method_info.name = method.method_name; | |||
| for (auto &name : method.inputs) { | |||
| worker_method_info.input_names.push_back(name); | |||
| } | |||
| work_list_.push_back(work); | |||
| worker_spec.methods.push_back(worker_method_info); | |||
| } | |||
| work.worker_spec = worker_spec; | |||
| work.servable_signature = signature; | |||
| work.worker_service = service; | |||
| work.servable = servable; | |||
| work_list_.push_back(work); | |||
| status = RegisterWorker(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "Register worker failed"; | |||
| return status; | |||
| } | |||
| servable_started_ = true; | |||
| status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory | |||
| << "', servable name: '" << servable_name | |||
| << "', specified version number: " << version_number | |||
| << ", started version numbers: " << real_versions; | |||
| MSI_LOG_INFO << status.StatusMessage(); | |||
| std::cout << status.StatusMessage() << std::endl; | |||
| return SUCCESS; | |||
| } | |||
| @@ -368,119 +241,39 @@ void Worker::StopServable(bool notify_master) { | |||
| } | |||
| void Worker::Clear() { | |||
| std::unique_lock<std::shared_mutex> lock(worker_shared_lock_); | |||
| ServableStorage::Instance().Clear(); | |||
| worker_grpc_server_ = nullptr; | |||
| if (clear_flag_.test_and_set()) { | |||
| return; | |||
| } | |||
| std::unique_lock<std::shared_mutex> lock(worker_shared_lock_); | |||
| MSI_LOG_INFO << "Start clear worker session"; | |||
| version_controller_.StopPollModelPeriodic(); | |||
| if (exit_notify_master_ && servable_started_) { | |||
| notify_master_->Unregister(); | |||
| } | |||
| if (session_ != nullptr) { | |||
| for (auto &it : work_list_) { | |||
| session_->UnloadModel(it.model_id); | |||
| } | |||
| for (auto &worker_item : work_list_) { | |||
| worker_item.servable->Clear(); | |||
| } | |||
| work_list_.clear(); | |||
| FinalizeEnv(); | |||
| session_ = nullptr; | |||
| py_task_queue_group_.Stop(); | |||
| cpp_preprocess_.Stop(); | |||
| cpp_postprocess_.Stop(); | |||
| ServableStorage::Instance().Clear(); | |||
| grpc_async_worker_server_ = nullptr; | |||
| servable_started_ = false; | |||
| MSI_LOG_INFO << "End clear worker session"; | |||
| } | |||
| bool Worker::HasCleared() { return !servable_started_; } | |||
| bool Worker::IsRunning() { return servable_started_; } | |||
| Worker::~Worker() { Clear(); } | |||
| void Worker::GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions) { | |||
| MSI_EXCEPTION_IF_NULL(real_versions); | |||
| // define version_strategy:"specific","lastest","multi" | |||
| if (version_strategy_ == kVersionStrategySpecific) { | |||
| real_versions->push_back(servable_spec.version_number); | |||
| return; | |||
| } | |||
| auto trans_to_integer = [](const std::string &str) -> uint32_t { | |||
| uint32_t parsed_value = 0; | |||
| for (auto c : str) { | |||
| if (c < '0' || c > '9') { | |||
| return 0; | |||
| } | |||
| parsed_value = parsed_value * 10 + c - '0'; | |||
| } | |||
| if (std::to_string(parsed_value) != str) { | |||
| return 0; | |||
| } | |||
| return parsed_value; | |||
| }; | |||
| uint64_t newest_version = 0; | |||
| std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name; | |||
| auto sub_dir = GetAllSubDirsNotFullPath(model_path); | |||
| static std::set<std::string> ignore_dir; | |||
| for (const auto &dir : sub_dir) { | |||
| if (dir == "__pycache__") continue; | |||
| auto version_parse = trans_to_integer(dir); | |||
| if (version_parse == 0) { | |||
| if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { | |||
| MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory | |||
| << ", model_name " << servable_spec.servable_name; | |||
| } | |||
| continue; | |||
| } | |||
| real_versions->push_back(version_parse); | |||
| if (version_parse > newest_version) { | |||
| newest_version = version_parse; | |||
| } | |||
| } | |||
| if (version_strategy_ == kVersionStrategyLastest) { | |||
| real_versions->clear(); | |||
| if (newest_version != 0) { | |||
| real_versions->push_back(newest_version); | |||
| } | |||
| } | |||
| } | |||
| Status Worker::LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, | |||
| std::vector<uint64_t> *real_versions) { | |||
| MSI_EXCEPTION_IF_NULL(real_versions); | |||
| auto model_directory = servable_spec.servable_directory; | |||
| auto model_name = servable_spec.servable_name; | |||
| if (!DirOrFileExist(model_directory + "/" + model_name)) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Model not found, model_directory " << model_directory << ", model_name " << model_name; | |||
| } | |||
| std::string model_path = model_directory + "/" + model_name; | |||
| auto version_directory = [model_path](int64_t version_number) { | |||
| return model_path + "/" + std::to_string(version_number); | |||
| }; | |||
| version_strategy_ = version_strategy; | |||
| // version_strategy:"specific","lastest","multi" | |||
| GetVersions(servable_spec, real_versions); | |||
| if (real_versions->size() == 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name; | |||
| } | |||
| for (auto real_version_number : *real_versions) { | |||
| if (!DirOrFileExist(version_directory(real_version_number))) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory " | |||
| << model_directory << ", model_name " << model_name; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) { | |||
| ServableWorkerContext context; | |||
| if (request_spec.version_number != 0) { | |||
| auto item = find_if(work_list_.begin(), work_list_.end(), [&](const ServableWorkerContext &v) { | |||
| return v.servable_spec.servable_name == request_spec.servable_name && | |||
| v.servable_spec.version_number == request_spec.version_number; | |||
| return v.worker_spec.servable_name == request_spec.servable_name && | |||
| v.worker_spec.version_number == request_spec.version_number; | |||
| }); | |||
| if (item != work_list_.end()) { | |||
| context = *item; | |||
| @@ -488,10 +281,10 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) | |||
| } else { | |||
| uint64_t max_version = 0; | |||
| for (auto &item : work_list_) { | |||
| if (item.servable_spec.servable_name == request_spec.servable_name && | |||
| item.servable_spec.version_number > max_version) { | |||
| if (item.worker_spec.servable_name == request_spec.servable_name && | |||
| item.worker_spec.version_number > max_version) { | |||
| context = item; | |||
| max_version = item.servable_spec.version_number; | |||
| max_version = item.worker_spec.version_number; | |||
| } | |||
| } | |||
| } | |||
| @@ -500,11 +293,11 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) | |||
// Default construction only; all real initialization happens in StartServable/InitEnv.
Worker::Worker() {}
| ssize_t Worker::GetBatchSize() const { | |||
| ssize_t batch_size_ret = -1; | |||
| for (auto service : work_list_) { | |||
| auto batch_size = session_->GetBatchSize(service.model_id); | |||
| if (batch_size != -1) { | |||
| size_t Worker::GetBatchSize() const { | |||
| size_t batch_size_ret = 1; | |||
| for (const auto &service : work_list_) { | |||
| auto batch_size = service.servable->GetBatchSize(); | |||
| if (batch_size != 0) { | |||
| batch_size_ret = batch_size; | |||
| break; | |||
| } | |||
| @@ -532,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) { | |||
| const int kWaitMaxHundredMs = 100; | |||
| int i; | |||
| for (i = 0; i < kWaitMaxHundredMs; i++) { // | |||
| if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) { | |||
| if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) { | |||
| instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped"); | |||
| return SYSTEM_ERROR; | |||
| } | |||
| @@ -32,6 +32,8 @@ | |||
| #include "worker/task_queue.h" | |||
| #include "worker/version_control/version_controller.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/sevable_base.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -53,11 +55,10 @@ class AsyncResult { | |||
| }; | |||
| struct ServableWorkerContext { | |||
| LoadServableSpec servable_spec; | |||
| WorkerSpec worker_spec; | |||
| ServableSignature servable_signature; | |||
| std::shared_ptr<WorkExecutor> worker_service = nullptr; | |||
| uint32_t model_id = 0; | |||
| std::string model_file_name; | |||
| std::shared_ptr<ServableBase> servable = nullptr; | |||
| }; | |||
| class MS_API Worker { | |||
| @@ -72,17 +73,14 @@ class MS_API Worker { | |||
| Status Run(const RequestSpec &request_spec, const std::vector<InstanceData> &inputs, std::vector<Instance> *outputs); | |||
| std::pair<Status, std::shared_ptr<AsyncResult>> RunAsync(const RequestSpec &request_spec, | |||
| const std::vector<InstanceData> &inputs); | |||
| Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | |||
| Status InitEnv(ModelType model_type, const std::map<std::string, std::string> &other_options); | |||
| Status FinalizeEnv(); | |||
| Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip, | |||
| int32_t port); | |||
| Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, | |||
| std::shared_ptr<BaseNotifyMaster> notify_master); | |||
| void StopServable(bool notify_master = true); | |||
| bool HasCleared(); | |||
| bool IsRunning(); | |||
| Status RegisterWorker(); | |||
| Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); | |||
| Status LoadModel(LoadServableSpec *servable_spec, uint64_t version, ServableWorkerContext *work); | |||
| void Update(); | |||
| Status StartVersionController(); | |||
| Status AddWorker(const ServableWorkerContext &work); | |||
| @@ -93,31 +91,24 @@ class MS_API Worker { | |||
| std::shared_ptr<TaskQueue> GetPyTaskQueuePostprocess() { return py_task_queue_group_.GetPostprocessTaskQueue(); } | |||
| std::shared_ptr<TaskQueue> GetCppTaskQueuePreprocess() { return cpp_preprocess_.GetTaskQueue(); } | |||
| std::shared_ptr<TaskQueue> GetCppTaskQueuePostprocess() { return cpp_postprocess_.GetTaskQueue(); } | |||
| ssize_t GetBatchSize() const; | |||
| size_t GetBatchSize() const; | |||
| private: | |||
| static std::shared_ptr<Worker> global_worker_; | |||
| std::vector<ServableWorkerContext> work_list_; | |||
| std::shared_ptr<serving::InferSession> session_ = nullptr; | |||
| std::string version_strategy_; | |||
| PyTaskQueueGroup py_task_queue_group_; | |||
| PreprocessThreadPool cpp_preprocess_; | |||
| PostprocessThreadPool cpp_postprocess_; | |||
| VersionController version_controller_; | |||
| LoadServableSpec base_spec_; | |||
| std::atomic_bool exit_notify_master_ = true; | |||
| std::atomic_bool servable_started_ = false; | |||
| std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; | |||
| std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr; | |||
| std::shared_ptr<MSWorkerServer> worker_grpc_server_ = nullptr; | |||
| std::shared_mutex worker_shared_lock_; | |||
| ServableWorkerContext GetServableWorker(const RequestSpec &request_spec); | |||
| Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, | |||
| std::vector<uint64_t> *real_version_number); | |||
| void GetVersions(const LoadServableSpec &servable_spec, std::vector<uint64_t> *real_versions); | |||
| }; | |||
| } // namespace serving | |||
| @@ -18,6 +18,7 @@ import threading | |||
| from functools import wraps | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import Master_ | |||
| _wait_and_clear_thread = None | |||
| @@ -59,6 +60,7 @@ def stop_on_except(func): | |||
| @wraps(func) | |||
| def handle_except(*args, **kwargs): | |||
| try: | |||
| ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message | |||
| func(*args, **kwargs) | |||
| except: | |||
| stop() | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
// NOTE(review): original header comment said "ms_manager.proto", but this file
// defines the distributed agent service (MSAgent) — confirm the actual file name.
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";

// Inference request forwarded from the distributed worker to one agent.
message DistributedPredictRequest {
  repeated Tensor inputs = 1;
}

// Inference result returned by the agent; error_msg reports agent-side failures.
message DistributedPredictReply {
  repeated Tensor outputs = 1;
  ErrorMsg error_msg = 2;
}

// Ask the agent listening at 'address' to exit.
message DistributedExitRequest {
  string address = 1;
}

message DistributedExitReply {
  ErrorMsg error_msg = 1;
}

// Service implemented by every worker agent, called by the distributed worker.
service MSAgent {
  rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {}
  rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {}
}
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
// NOTE(review): original header comment said "ms_manager.proto"; these messages
// describe agent registration/exit toward the worker — confirm the actual file name.
syntax = "proto3";

package mindspore.serving.proto;

import "mindspore_serving/proto/ms_service.proto";

// Description of one agent: its rank, batch size and model input/output tensors.
message AgentSpec {
  int64 rank_id = 1;
  int64 batch_size = 2;
  repeated Tensor inputs = 3;
  repeated Tensor outputs = 4;
}

// Agent -> worker: register the agents reachable at 'address'.
message AgentRegisterRequest {
  repeated AgentSpec agent_spec = 1;
  string address = 2;
}

message AgentRegisterReply {
  ErrorMsg error_msg = 1;
}

// Agent -> worker: notify that the agents at 'address' are exiting.
message AgentExitRequest {
  repeated AgentSpec agent_spec = 1;
  string address = 2;
}

message AgentExitReply {
  ErrorMsg error_msg = 1;
}

// Agent -> worker: notify that agent startup failed; carries no payload.
message AgentFailedRequest {
}

message AgentFailedReply {
  ErrorMsg error_msg = 1;
}
| @@ -80,6 +80,8 @@ message Tensor { | |||
| // for string type and images, the dtype is MS_BYTES. | |||
| repeated bytes bytes_val = 4; | |||
| int64 size = 5; | |||
| } | |||
| message ServableSpec { | |||
| @@ -20,8 +20,14 @@ syntax = "proto3"; | |||
| package mindspore.serving.proto; | |||
| import "mindspore_serving/proto/ms_service.proto"; | |||
| import "mindspore_serving/proto/ms_master.proto"; | |||
| import "mindspore_serving/proto/ms_distributed.proto"; | |||
| service MSWorker { | |||
| // for master | |||
| rpc Predict(PredictRequest) returns (PredictReply) {} | |||
| rpc Exit(ExitRequest) returns (ExitReply) {} | |||
| // for worker agent | |||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||
| rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} | |||
| } | |||
| @@ -12,11 +12,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Inferface for start up servable""" | |||
| """Interface for start up servable""" | |||
| import threading | |||
| from functools import wraps | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import Worker_ | |||
| from .register.preprocess import preprocess_storage | |||
| from .register.postprocess import postprocess_storage | |||
| @@ -77,6 +78,7 @@ def stop_on_except(func): | |||
| @wraps(func) | |||
| def handle_except(*args, **kwargs): | |||
| try: | |||
| ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message | |||
| func(*args, **kwargs) | |||
| except: | |||
| stop() | |||
| @@ -0,0 +1,250 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker agent startup""" | |||
| import os | |||
| import time | |||
| from multiprocessing import Process, Pipe | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker.distributed import worker_agent | |||
| def _get_local_ip(rank_list, port): | |||
| """Get the local ip from the rank table config""" | |||
| import socket | |||
| ip_list = [] | |||
| for item in rank_list: | |||
| ip_list.append(item.ip) | |||
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |||
| for ip in ip_list: | |||
| try: | |||
| s.bind((ip, port)) | |||
| logger.info(f"Get local machine ip success, ip {ip}") | |||
| return ip | |||
| # pylint: disable=bare-except | |||
| except: | |||
| pass | |||
| raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") | |||
def _update_model_files_path(model_files, group_config_files):
    """Resolve model and group config files against the servable script directory.

    Args:
        model_files: File names relative to the directory of this script.
        group_config_files: Group config file names relative to the same directory.

    Returns:
        Tuple (absolute model files, absolute group config files).

    Raises:
        RuntimeError: If any resolved file is not readable.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    logger.info(f"input model files: {model_files}")
    logger.info(f"input group config files: {group_config_files}")

    def resolve(files, kind):
        # Join each name with the script directory and verify it is readable.
        resolved = []
        for item in files:
            file_name = os.path.join(script_dir, item)
            if not os.access(file_name, os.R_OK):
                raise RuntimeError(f"Cannot access {kind} '{file_name}'")
            resolved.append(file_name)
        return resolved

    model_files_temp = resolve(model_files, "model file")
    group_files_temp = resolve(group_config_files, "group config file")
    logger.info(f"absolute model files: {model_files_temp}")
    logger.info(f"absolute group config files: {group_files_temp}")
    return model_files_temp, group_files_temp
| def _make_json_table_file(distributed_config): | |||
| """Make rank table json file""" | |||
| rank_size = len(distributed_config.rank_list) | |||
| runtime_dir = os.path.abspath(".") | |||
| time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) | |||
| rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") | |||
| with open(rank_table_file_name, "w") as fp: | |||
| fp.write(distributed_config.rank_table_content) | |||
| return rank_table_file_name | |||
# Pipe protocol messages exchanged between the startup (parent) process and agents.
signal_success = "Success"      # all agents reported successful startup
signal_exit = "Exit"            # ask the peer process to exit
signal_heartbeat = "HeartBeat"  # parent liveness probe while agents wait
def _recv_parent(index, recv_pipe):
    """Receive message from Start up process.
    Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit.
    Return True on receiving signal_success."""
    try:
        while True:
            heartbeat_count = 0
            # Poll the pipe every 100ms; abort when the worker stops or the
            # parent stays silent for too long.
            while not recv_pipe.poll(0.1):
                if ExitSignalHandle_.has_stopped():
                    logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker")
                    return False
                heartbeat_count += 1
                if heartbeat_count >= 30:  # 3s without any parent message: parent assumed dead
                    logger.warning(f"Child {index}: Exit on failure of receiving parent message")
                    return False
            parent_signal = recv_pipe.recv()
            # Heartbeats only reset the silence timer; any other message ends the wait loop.
            if parent_signal != signal_heartbeat:
                break
        if parent_signal == signal_success:
            logger.info(f"Child {index}: Receive success")
            return True
        if parent_signal == signal_exit:
            logger.warning(f"Child {index}: Exit on receiving exit message")
        else:
            logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}")
    # pylint: disable=broad-except
    except Exception as e:
        # Broken pipe or any other failure while waiting counts as "do not continue".
        logger.warning(f"Child {index}: Exit on exception: {e}")
    return False
def _agent_process(send_pipe, recv_pipe, index, start_config):
    """Agent subprocess entry: start one worker agent, then synchronize with the parent."""
    try:
        # listening success or failed message from parent process
        ExitSignalHandle_.start()  # Set flag to running and receive Ctrl+C message
        worker_agent.start_worker_agent(start_config=start_config)
        # Report successful startup to the parent.
        send_pipe.send((index, signal_success))
        # Block until the parent confirms ALL agents started (or orders exit).
        success_msg = _recv_parent(index, recv_pipe)
        if not success_msg:
            worker_agent.stop()
        send_pipe.close()
        recv_pipe.close()
    # pylint: disable=broad-except
    except Exception as e:
        # Forward the exception to the parent so it can shut the sibling agents down,
        # then re-raise so this subprocess exits with an error.
        logger.error(f"Child {index}: Catch exception and notify exit of others")
        send_pipe.send((index, e))
        worker_agent.stop()
        raise
def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list):
    """Wait until every agent subprocess reports startup status.

    Sends heartbeats while waiting, broadcasts signal_exit when an agent dies or
    raises, and broadcasts signal_success once all agents reported success.
    Returns True when all agents started, False otherwise."""
    def send_pipe_msg(send_pipe, msg):
        # Best-effort send: the peer process may already have exited.
        try:
            send_pipe.send(msg)
        # pylint: disable=broad-except
        except Exception as e:
            logger.warning(f"Send pipe message exception happen: {e}")

    count = len(send_pipe_list)
    # Expect exactly one status message per agent.
    for _ in range(count):
        while True:
            if p_recv_pipe.poll(0.1):
                break
            # If any agent died without reporting, order the survivors to exit.
            for send_pipe, process in zip(send_pipe_list, subprocess_list):
                if process.is_alive():
                    continue
                logger.warning("Fail to start agents because of death of one agent")
                for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list):
                    if process_x.is_alive():
                        send_pipe_msg(send_pipe_x, signal_exit)
                return False
            # Keep each child's parent-liveness timer from expiring (see _recv_parent).
            for send_pipe in send_pipe_list:
                send_pipe_msg(send_pipe, signal_heartbeat)
        _, msg = p_recv_pipe.recv()
        if isinstance(msg, Exception):
            # One agent raised during startup: tell every agent to exit.
            logger.warning("Fail to start agents because of exception raise by one agent")
            for send_pipe in send_pipe_list:
                send_pipe_msg(send_pipe, signal_exit)
            return False
    for send_pipe in send_pipe_list:
        send_pipe_msg(send_pipe, signal_success)
    logger.info("Success to start agents")
    return True
def _startup_all_agents(common_meta, worker_ip, worker_port,
                        agent_ip, agent_start_port, device_id_list, rank_id_list,
                        model_files, group_config_files, rank_table_file):
    """Start up all agents in one machine.

    Spawns one subprocess per local device, wires up the parent/child pipes, and
    notifies the worker when agent startup fails on this machine."""
    servable_name = common_meta.servable_name
    index = 0
    send_pipe_list = []
    subprocess_list = []
    # One shared pipe for child->parent status; one private pipe per child for parent->child.
    c_send_pipe, p_recv_pipe = Pipe()
    for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files,
                                                          group_config_files):
        p_send_pipe, c_recv_pipe = Pipe()
        send_pipe_list.append(p_send_pipe)
        # Consecutive ports starting at agent_start_port, one per local agent.
        agent_port = agent_start_port + index
        start_config = AgentStartUpConfig_()
        start_config.rank_id = rank_id
        start_config.device_id = device_id
        start_config.model_file_name = model_file
        start_config.group_file_name = group_file
        start_config.rank_table_json_file_name = rank_table_file
        start_config.agent_ip = agent_ip
        start_config.agent_port = agent_port
        start_config.worker_ip = worker_ip
        start_config.worker_port = worker_port
        start_config.common_meta = common_meta
        process = Process(target=_agent_process,
                          args=(c_send_pipe, c_recv_pipe, index, start_config),
                          name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}")
        process.start()
        subprocess_list.append(process)
        index += 1
    ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list)
    if not ret:
        # Let the distributed worker know that agent startup failed on this machine.
        WorkerAgent_.notify_failed(worker_ip, worker_port)
def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000):
    """Start up all needed worker agents on one machine.

    Args:
        worker_ip (str): The ip of the distributed worker the agents link to.
        worker_port (int): The port of the distributed worker.
        model_files (tuple/list of str): Model file names, one per local device.
        group_config_files (tuple/list of str): Group config file names, one per local device.
        agent_start_port (int): First port used by the agents; consecutive ports follow. Default: 7000.

    Raises:
        RuntimeError: If the local machine cannot be identified from the rank table, or the
            number of model/group config files does not match the local device count.
    """
    check_type.check_str("worker_ip", worker_ip)
    check_type.check_ip_port("worker_port", worker_port)
    check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7)
    # The file lists hold file *names*, so validate them as str lists, not int lists.
    # NOTE(review): assumes check_type provides check_and_as_str_tuple_list — confirm.
    model_files = check_type.check_and_as_str_tuple_list("model_files", model_files)
    group_config_files = check_type.check_and_as_str_tuple_list("group_config_files", group_config_files)
    distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)

    # get machine ip
    rank_list = distributed_config.rank_list
    local_ip = _get_local_ip(rank_list, agent_start_port)
    # get all device_id and rank_id belonging to this machine
    local_device_id_list = []
    local_rank_id_list = []
    for rank_id, item in enumerate(rank_list):
        if item.ip == local_ip:
            local_device_id_list.append(item.device_id)
            local_rank_id_list.append(rank_id)
    # handle model files and group config files; report the *count*, not the raw list
    if len(local_device_id_list) != len(model_files):
        raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to "
                           f"model files size {len(model_files)}, model files: {model_files}")
    if len(local_device_id_list) != len(group_config_files):
        raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to "
                           f"group config files size {len(group_config_files)}, group config files: "
                           f"{group_config_files}")
    model_files, group_config_files = _update_model_files_path(model_files, group_config_files)

    # make json table file and start all local agents
    rank_table_file = _make_json_table_file(distributed_config)
    _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port,
                        local_device_id_list, local_rank_id_list,
                        model_files, group_config_files, rank_table_file)
| @@ -0,0 +1,131 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker startup""" | |||
| from mindspore_serving._mindspore_serving import Worker_ | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear | |||
| from mindspore_serving.worker._worker import stop_on_except, _load_servable_config | |||
@stop_on_except
def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1,
                               worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100,
                               wait_agents_time_in_seconds=300):
    r"""
    Start up the distributed servable named 'servable_name' defined in 'servable_directory', and link the worker
    to the master through gRPC (master_ip, master_port).

    Serving has two running modes. One is running in a single process, providing the Serving service of a single model.
    The other includes a master and multiple workers. This interface is for the second scenario.

    The master is responsible for providing the Serving access interface for clients,
    while the worker is responsible for providing the inference service of the specific model. The communications
    between the master and workers through gRPC are defined as (master_ip, master_port) and (worker_ip, worker_port).

    Args:
        servable_directory (str): The directory where the servable is located in. There is expected to be a
            directory named `servable_name` under it. For more detail:
            `How to config Servable <https://www.mindspore.cn/tutorial/inference/zh-CN/master/serving_model.html>`_ .
        servable_name (str): The servable name.
        rank_table_json_file (str): The rank table json file name.
        version_number (int): Servable version number to be loaded. The version number should be a positive integer,
            starting from 1, and 0 means to load the latest version. Default: 1.
        worker_ip (str): The worker ip the master and agents linked to.
        worker_port (int): The worker port the master and agents linked to.
        master_ip (str): The master ip the worker linked to.
        master_port (int): The master port the worker linked to.
        wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents.

    Examples:
        >>> import os
        >>> from mindspore_serving.worker.distributed import start_distributed_servable
        >>>
        >>> servable_dir = os.path.abspath(".")
        >>> start_distributed_servable(servable_dir, "matmul", rank_table_json_file="hccl_rank_table_8p.json", \
        ...                            master_ip="127.0.0.1", master_port=6500, \
        ...                            worker_ip="127.0.0.1", worker_port=6600)
    """
    check_type.check_str('servable_directory', servable_directory)
    check_type.check_str('servable_name', servable_name)
    check_type.check_int('version_number', version_number, 0)
    # version 0 means "latest"; for a fresh distributed servable that is version 1.
    if version_number == 0:
        version_number = 1
    check_type.check_str('rank_table_json_file', rank_table_json_file)
    check_type.check_str('master_ip', master_ip)
    check_type.check_ip_port('master_port', master_port)
    check_type.check_str('worker_ip', worker_ip)
    check_type.check_ip_port('worker_port', worker_port)
    _load_servable_config(servable_directory, servable_name)
    Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number,
                                       master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds)
    _start_py_task(Worker_.get_batch_size())
    _start_wait_and_clear()
@stop_on_except
def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1,
                                         worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300):
    r"""
    Start up the distributed servable named 'servable_name' defined in 'servable_directory', and the worker will
    run in the process of the master.

    Serving has two running modes. One is running in a single process, providing the Serving service of a single model.
    The other includes a master and multiple workers. This interface is for the first scenario.

    Args:
        servable_directory (str): The directory where the servable is located in. There is expected to be a
            directory named `servable_name` under it. For more detail:
            `How to config Servable <https://www.mindspore.cn/tutorial/inference/zh-CN/master/serving_model.html>`_ .
        servable_name (str): The servable name.
        rank_table_json_file (str): The rank table json file name.
        version_number (int): Servable version number to be loaded. The version number should be a positive integer,
            starting from 1, and 0 means to load the latest version. Default: 1.
        worker_ip (str): The worker ip the agents linked to.
        worker_port (int): The worker port the agents linked to.
        wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents.

    Examples:
        >>> import os
        >>> from mindspore_serving.worker.distributed import start_distributed_servable_in_master
        >>> from mindspore_serving import master
        >>>
        >>> servable_dir = os.path.abspath(".")
        >>> start_distributed_servable_in_master(servable_dir, "matmul", rank_table_json_file="hccl_rank_table_8p.json")
        >>>
        >>> master.start_grpc_server("0.0.0.0", 5500)
        >>> master.start_restful_server("0.0.0.0", 1500)
    """
    check_type.check_str('servable_directory', servable_directory)
    check_type.check_str('servable_name', servable_name)
    check_type.check_int('version_number', version_number, 0)
    # version 0 means "latest"; for a fresh distributed servable that is version 1.
    if version_number == 0:
        version_number = 1
    check_type.check_str('rank_table_json_file', rank_table_json_file)
    check_type.check_str('worker_ip', worker_ip)
    check_type.check_ip_port('worker_port', worker_port)
    _load_servable_config(servable_directory, servable_name)
    Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file,
                                                 version_number, worker_ip, worker_port, wait_agents_time_in_seconds)
    _start_py_task(Worker_.get_batch_size())
    _start_wait_and_clear()
| @@ -0,0 +1,43 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker register""" | |||
| from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker.common import get_servable_dir | |||
| from mindspore_serving import log as logger | |||
| def declare_distributed_servable(rank_size, stage_size, with_batch_dim, without_batch_dim_inputs): | |||
| """declare distributed servable in servable_config.py""" | |||
| check_type.check_bool('with_batch_dim', with_batch_dim) | |||
| meta = ServableMeta_() | |||
| meta.common_meta.servable_name = get_servable_dir() | |||
| meta.common_meta.with_batch_dim = with_batch_dim | |||
| if without_batch_dim_inputs: | |||
| without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', | |||
| without_batch_dim_inputs, 0) | |||
| meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs | |||
| # init distributed servable meta info | |||
| check_type.check_int("rank_size", rank_size, 1) | |||
| check_type.check_int("stage_size", stage_size, 1) | |||
| meta.distributed_meta.rank_size = rank_size | |||
| meta.distributed_meta.stage_size = stage_size | |||
| ServableStorage_.declare_distributed_servable(meta) | |||
| logger.info(f"Declare distributed servable, servable_name: {meta.common_meta.servable_name} " | |||
| f", rank_size: {rank_size} , stage_size: {stage_size}, with_batch_dim: {with_batch_dim} " | |||
| f", without_batch_dim_inputs: {without_batch_dim_inputs}") | |||
| @@ -0,0 +1,66 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker agent""" | |||
| import os | |||
| import threading | |||
| from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ | |||
| from mindspore_serving import log as logger | |||
| def start_worker_agent(start_config): | |||
| """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents | |||
| """ | |||
| if not isinstance(start_config, AgentStartUpConfig_): | |||
| raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") | |||
| os.environ["RANK_ID"] = str(start_config.rank_id) | |||
| os.environ["DEVICE_ID"] = str(start_config.device_id) | |||
| os.environ["MS_ENABLE_HCCL"] = "1" | |||
| os.environ["PARA_GROUP_FILE"] = start_config.group_file_name | |||
| os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name | |||
| for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE", | |||
| "LD_LIBRARY_PATH", "PYTHONPATH"): | |||
| logger.info(f"Env {item}: {os.getenv(item, '')}") | |||
| WorkerAgent_.start_agent(start_config) | |||
| start_wait_and_clear() | |||
| _wait_and_clear_thread = None | |||
| def start_wait_and_clear(): | |||
| """Waiting for Ctrl+C, and clear up environment""" | |||
| def thread_func(): | |||
| logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------") | |||
| print("Serving worker: wait for Ctrl+C to exit ------------------------------------") | |||
| WorkerAgent_.wait_and_clear() | |||
| logger.info("Serving worker: exited ------------------------------------") | |||
| print("Serving worker: exited ------------------------------------") | |||
| global _wait_and_clear_thread | |||
| if not _wait_and_clear_thread: | |||
| _wait_and_clear_thread = threading.Thread(target=thread_func) | |||
| _wait_and_clear_thread.start() | |||
| def stop(): | |||
| r""" | |||
| Stop the running of agent. | |||
| """ | |||
| WorkerAgent_.stop_and_clear() | |||
| @@ -35,28 +35,6 @@ method_tag_predict = PredictPhaseTag_.kPredictPhaseTag_Predict | |||
| method_tag_postprocess = PredictPhaseTag_.kPredictPhaseTag_Postprocess | |||
| class _ServableStorage: | |||
| """Declare servable info""" | |||
| def __init__(self): | |||
| pass | |||
| @staticmethod | |||
| def declare_servable(servable_meta): | |||
| """Declare servable info excluding method, input and output count""" | |||
| ServableStorage_.declare_servable(servable_meta) | |||
| @staticmethod | |||
| def declare_servable_input_output(servable_name, inputs_count, outputs_count): | |||
| """Declare input and output count of servable""" | |||
| ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) | |||
| @staticmethod | |||
| def register_method(method_signature): | |||
| """Declare method of servable""" | |||
| ServableStorage_.register_method(method_signature) | |||
| class _TensorDef: | |||
| """Data flow item, for definitions of data flow in a method""" | |||
| @@ -251,7 +229,7 @@ def call_servable(*args): | |||
| servable_name = get_servable_dir() | |||
| inputs_count, outputs_count = method_def_ast_meta_[_call_servable_name] | |||
| _ServableStorage.declare_servable_input_output(servable_name, inputs_count, outputs_count) | |||
| ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) | |||
| if inputs_count != len(args): | |||
| raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', given servable input " | |||
| f"size {len(args)} not match '{servable_name}' ast parse size {inputs_count}") | |||
| @@ -467,7 +445,7 @@ def register_method(output_names): | |||
| f", servable_name {method_def_context_.servable_name}, inputs: {input_names}, outputs: " | |||
| f"{output_names}") | |||
| _ServableStorage.register_method(method_def_context_) | |||
| ServableStorage_.register_method(method_def_context_) | |||
| return func | |||
| return register | |||
| @@ -14,11 +14,10 @@ | |||
| # ============================================================================ | |||
| """Servable declaration interface""" | |||
| from mindspore_serving._mindspore_serving import ServableMeta_ | |||
| from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker.common import get_servable_dir | |||
| from mindspore_serving import log as logger | |||
| from .method import _ServableStorage | |||
| def declare_servable(servable_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None): | |||
| @@ -37,19 +36,25 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N | |||
| RuntimeError: The type or value of the parameters is invalid. | |||
| """ | |||
| check_type.check_str('servable_file', servable_file) | |||
| check_type.check_str('model_format', model_format) | |||
| check_type.check_bool('with_batch_dim', with_batch_dim) | |||
| meta = ServableMeta_() | |||
| meta.common_meta.servable_name = get_servable_dir() | |||
| meta.common_meta.with_batch_dim = with_batch_dim | |||
| if without_batch_dim_inputs: | |||
| without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', | |||
| without_batch_dim_inputs, 0) | |||
| meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs | |||
| # init local servable meta info | |||
| check_type.check_str('servable_file', servable_file) | |||
| check_type.check_str('model_format', model_format) | |||
| model_format = model_format.lower() | |||
| if model_format not in ("om", "mindir"): | |||
| raise RuntimeError("model format can only be OM or MindIR") | |||
| meta = ServableMeta_() | |||
| meta.servable_name = get_servable_dir() | |||
| meta.servable_file = servable_file | |||
| meta.set_model_format(model_format) | |||
| meta.with_batch_dim = with_batch_dim | |||
| meta.local_meta.servable_file = servable_file | |||
| meta.local_meta.set_model_format(model_format) | |||
| if isinstance(options, dict): | |||
| for k, w in options.items(): | |||
| check_type.check_str("options key", k) | |||
| @@ -61,14 +66,10 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N | |||
| raise RuntimeError(f"Parameter 'options' should be None, dict of <str,str> or AclOptions, but " | |||
| f"gotten {type(options)}") | |||
| if options: | |||
| meta.options = options | |||
| if without_batch_dim_inputs: | |||
| without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', | |||
| without_batch_dim_inputs, 0) | |||
| meta.without_batch_dim_inputs = without_batch_dim_inputs | |||
| meta.local_meta.options = options | |||
| _ServableStorage.declare_servable(meta) | |||
| logger.info(f"Declare servable, servable_name: {meta.servable_name} " | |||
| ServableStorage_.declare_servable(meta) | |||
| logger.info(f"Declare servable, servable_name: {meta.common_meta.servable_name} " | |||
| f", servable_file: {servable_file} , model_format: {model_format}, with_batch_dim: {with_batch_dim} " | |||
| f", options: {options}, without_batch_dim_inputs: {without_batch_dim_inputs}") | |||
| @@ -27,6 +27,7 @@ | |||
| #include "worker/worker.h" | |||
| #include "worker/notfiy_master/local_notify.h" | |||
| #include "worker/context.h" | |||
| #include "worker/local_servable/local_sevable.h" | |||
| #include "master/grpc/grpc_process.h" | |||
| #include "mindspore_serving/proto/ms_service.pb.h" | |||
| @@ -102,16 +103,21 @@ class TestMasterWorker : public UT::Common { | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable(servable_dir, servable_name, version_number, notify_master); | |||
| auto servable = std::make_shared<LocalModelServable>(); | |||
| auto status = servable->StartServable(servable_dir, servable_name, version_number); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| status = Worker::GetInstance().StartServable(servable, notify_master); | |||
| return status; | |||
| } | |||
| static void DeclareServable(const std::string &servable_name, const std::string &servable_file, | |||
| const std::string &model_type, bool with_batch_dim = false) { | |||
| ServableMeta servable_meta; | |||
| servable_meta.servable_name = servable_name; | |||
| servable_meta.servable_file = servable_file; | |||
| servable_meta.SetModelFormat(model_type); | |||
| servable_meta.with_batch_dim = with_batch_dim; | |||
| servable_meta.common_meta.servable_name = servable_name; | |||
| servable_meta.common_meta.with_batch_dim = with_batch_dim; | |||
| servable_meta.local_meta.servable_file = servable_file; | |||
| servable_meta.local_meta.SetModelFormat(model_type); | |||
| // declare_servable | |||
| ServableStorage::Instance().DeclareServable(servable_meta); | |||
| } | |||
| @@ -30,10 +30,7 @@ TEST_F(TestStartWorker, test_worker_start_success) { | |||
| DeclareServable("test_servable", "test_add.mindir", "mindir", true); | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_TRUE(status.IsSuccess()); | |||
| } | |||
| @@ -43,10 +40,7 @@ TEST_F(TestStartWorker, test_worker_start_error_model_file_name) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| auto status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), "Load model failed, servable directory: "); | |||
| } | |||
| @@ -57,12 +51,8 @@ TEST_F(TestStartWorker, test_worker_start_error_version_number) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| int error_version_number = 2; | |||
| Status status = | |||
| Worker::GetInstance().StartServable("test_servable_dir", "test_servable", error_version_number, notify_master); | |||
| auto status = StartServable("test_servable_dir", "test_servable", error_version_number); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), | |||
| "Start servable failed, there is no servable of" | |||
| @@ -78,11 +68,8 @@ TEST_F(TestStartWorker, test_worker_start_multi_version_number) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| int version_number = 0; | |||
| Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", version_number, notify_master); | |||
| Status status = StartServable(servable_dir, "test_servable", version_number); | |||
| EXPECT_TRUE(status.IsSuccess()); | |||
| } | |||
| @@ -96,10 +83,7 @@ TEST_F(TestStartWorker, test_worker_start_version_number_no_valid) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", 0, notify_master); | |||
| Status status = StartServable(servable_dir, "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), | |||
| "Start servable failed, there is no servable of" | |||
| @@ -112,11 +96,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_dir) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| std::string error_servable_dir = "test_servable_dir_error"; | |||
| Status status = Worker::GetInstance().StartServable(error_servable_dir, "test_servable", 0, notify_master); | |||
| Status status = StartServable(error_servable_dir, "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), | |||
| "Start servable failed, there is no servable of" | |||
| @@ -129,11 +110,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_name) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| std::string error_servable_name = "test_servable_error"; | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", error_servable_name, 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", error_servable_name, 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), "'test_servable_error' has not been registered"); | |||
| } | |||
| @@ -144,24 +122,18 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_format) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), "Cannot find session registered for device type Ascend and model type OM"); | |||
| ExpectContainMsg(status.StatusMessage(), "Not support device type Ascend and model type OM. "); | |||
| } | |||
| TEST_F(TestStartWorker, test_worker_start_no_registered_method) { | |||
| Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); | |||
| Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); | |||
| DeclareServable("test_servable", "test_add.mindir", "mindir", true); | |||
| // no registered method | |||
| // RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 2); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), "There is no method registered for servable"); | |||
| } | |||
| @@ -181,10 +153,7 @@ TEST_F(TestStartWorker, test_worker_start_multi_method) { | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); | |||
| RegisterMethod("test_servable", "add_common2", {"x1", "x2"}, {"y"}, 2, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_TRUE(status.IsSuccess()); | |||
| } | |||
| @@ -194,10 +163,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_input_count_not_match) | |||
| size_t servable_input_count = 1; | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, servable_input_count, 1); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), | |||
| "The inputs count 1 registered in method not equal to " | |||
| @@ -210,10 +176,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_output_count_not_match | |||
| size_t servable_output_count = 2; | |||
| RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, servable_output_count); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), | |||
| "The outputs count 2 registered in method not equal to " | |||
| @@ -241,10 +204,7 @@ TEST_F(TestStartWorker, test_worker_start_preprocess_not_found) { | |||
| ServableStorage::Instance().RegisterMethod(method_signature); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), " preprocess preprocess_fake_fun not defined") | |||
| } | |||
| @@ -269,10 +229,7 @@ TEST_F(TestStartWorker, test_worker_start_postprocess_not_found) { | |||
| ServableStorage::Instance().RegisterMethod(method_signature); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_FALSE(status.IsSuccess()); | |||
| ExpectContainMsg(status.StatusMessage(), " postprocess postprocess_fake_fun not defined") | |||
| } | |||
| @@ -300,10 +257,7 @@ TEST_F(TestStartWorker, test_worker_start_with_preproces_and_postprocess_success | |||
| ServableStorage::Instance().RegisterMethod(method_signature); | |||
| // start_servable | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| ServableContext::Instance()->SetDeviceId(0); | |||
| ServableContext::Instance()->SetDeviceTypeStr("Ascend"); | |||
| Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); | |||
| Status status = StartServable("test_servable_dir", "test_servable", 0); | |||
| EXPECT_TRUE(status.IsSuccess()); | |||
| } | |||
| @@ -16,21 +16,24 @@ | |||
| set -e | |||
| CURRPATH=$(cd "$(dirname $0)" || exit; pwd) | |||
| CURRPATH=$( | |||
| cd "$(dirname $0)" || exit | |||
| pwd | |||
| ) | |||
| if [ $# -gt 0 ]; then | |||
| if [ $1 == "python" ]; then | |||
| echo "run python ut" | |||
| bash ${CURRPATH}/python/runtest.sh $2 | |||
| elif [ $1 == "cpp" ]; then | |||
| echo "run cpp ut" | |||
| bash ${CURRPATH}/cpp/runtest.sh | |||
| fi | |||
| else | |||
| echo "run all ut" | |||
| # 1.run python testcases | |||
| if [ $1 == "python" ]; then | |||
| echo "run python ut" | |||
| bash ${CURRPATH}/python/runtest.sh $2 | |||
| # 2.run c++ ut testcases | |||
| elif [ $1 == "cpp" ]; then | |||
| echo "run cpp ut" | |||
| bash ${CURRPATH}/cpp/runtest.sh | |||
| fi | |||
| else | |||
| echo "run all ut" | |||
| # 1.run python testcases | |||
| bash ${CURRPATH}/python/runtest.sh $2 | |||
| # 2.run c++ ut testcases | |||
| bash ${CURRPATH}/cpp/runtest.sh | |||
| fi | |||