| @@ -356,6 +356,7 @@ void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAg | |||
| for (auto &out : specs.outputs()) { | |||
| TensorInfo info; | |||
| info.data_type = ProtoTensor::TransDataType2Inference(out.dtype()); | |||
| info.size = out.size(); | |||
| for (auto &dim : out.shape().dims()) { | |||
| info.shape.push_back(dim); | |||
| } | |||
| @@ -28,6 +28,16 @@ | |||
| namespace mindspore::serving { | |||
| void PyWorker::OnEndStartServable(const std::string &servable_directory, const std::string &servable_name, | |||
| uint32_t spec_version_number, uint32_t started_version_number) { | |||
| auto status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory | |||
| << "', servable name: '" << servable_name | |||
| << "', specified version number: " << spec_version_number | |||
| << ", started version numbers: " << started_version_number; | |||
| MSI_LOG_INFO << status.StatusMessage(); | |||
| std::cout << status.StatusMessage() << std::endl; | |||
| } | |||
| void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, | |||
| const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, | |||
| uint32_t worker_port) { | |||
| @@ -52,6 +62,7 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| OnEndStartServable(model_directory, model_name, version_number, servable->GetServableVersion()); | |||
| } | |||
| void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name, | |||
| @@ -70,6 +81,7 @@ void PyWorker::StartServableInMaster(const std::string &model_directory, const s | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| OnEndStartServable(model_directory, model_name, version_number, servable->GetServableVersion()); | |||
| } | |||
| void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, | |||
| @@ -99,6 +111,7 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| OnEndStartServable(servable_directory, servable_name, version_number, servable->GetServableVersion()); | |||
| } | |||
| void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, | |||
| @@ -127,6 +140,7 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| OnEndStartServable(servable_directory, servable_name, version_number, servable->GetServableVersion()); | |||
| } | |||
| TaskItem PyWorker::GetPyTask() { | |||
| @@ -55,6 +55,10 @@ class MS_API PyWorker { | |||
| static void PushPostprocessPyResult(const py::tuple &output_batch); | |||
| static void PushPostprocessPyFailed(int count); | |||
| private: | |||
| static void OnEndStartServable(const std::string &servable_directory, const std::string &servable_name, | |||
| uint32_t spec_version_number, uint32_t started_version_number); | |||
| }; | |||
| } // namespace mindspore::serving | |||
| @@ -30,6 +30,7 @@ grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::Dis | |||
| proto::DistributedPredictReply *reply) { | |||
| MSI_LOG(INFO) << "Begin call service Eval"; | |||
| WorkerAgent::Instance().Run(*request, reply); | |||
| MSI_LOG(INFO) << "End call service Eval"; | |||
| return grpc::Status::OK; | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/agent_startup.h" | |||
| #include <fstream> | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| namespace mindspore { | |||
| @@ -25,7 +26,7 @@ WorkerAgentStartUp &WorkerAgentStartUp::Instance() { | |||
| } | |||
| Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { | |||
| return Status(); | |||
| return FAILED; | |||
| } | |||
| Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { | |||
| @@ -42,16 +42,7 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr | |||
| proto::AgentExitReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(request); | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| for (auto &spec : request->agent_spec()) { | |||
| WorkerAgentSpec agent_spec; | |||
| agent_spec.agent_address = request->address(); | |||
| GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); | |||
| Status status(FAILED); | |||
| status = servable_->UnregisterAgent(agent_spec); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG(ERROR) << "Agent Exit FAILED"; | |||
| } | |||
| } | |||
| servable_->OnAgentExit(); | |||
| if (Worker::GetInstance().IsRunning()) { | |||
| Worker::GetInstance().StopServable(); | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <set> | |||
| #include <fstream> | |||
| #include "worker/distributed_worker/notify_agent/notify_agent.h" | |||
| #include "common/exit_handle.h" | |||
| @@ -65,13 +66,17 @@ Status DistributedServable::GetDistributedServableConfig(DistributedServableConf | |||
| void DistributedServable::SetWaitAgentsPromise(bool flag) { | |||
| if (!promise_set_flag_.test_and_set()) { | |||
| agents_promise_.set_value(flag); | |||
| registered_end_flag_ = true; | |||
| } | |||
| } | |||
| Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| if (registered_end_flag_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Distributed servable has ended up registration"; | |||
| } | |||
| if (agent_spec.rank_id < config_.distributed_meta.rank_size) { | |||
| if (agent_spec.rank_id >= config_.distributed_meta.rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; | |||
| } | |||
| @@ -82,7 +87,7 @@ Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| return SUCCESS; | |||
| } | |||
| context.agent_spec_ = agent_spec; | |||
| std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotfiyAgent>(agent_spec.agent_address); | |||
| std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotifyAgent>(agent_spec.agent_address); | |||
| context.notify_agent_ = notify_agent; | |||
| agent_spec_map_[agent_spec.rank_id] = context; | |||
| @@ -98,18 +103,11 @@ void DistributedServable::Clear() { | |||
| agent.second.notify_agent_->Exit(); | |||
| } | |||
| agent_spec_map_.clear(); | |||
| MSI_LOG_INFO << "End Clear servable"; | |||
| model_loaded_ = false; | |||
| MSI_LOG_INFO << "End clear distributed servable"; | |||
| } | |||
| Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { | |||
| if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { | |||
| iter = agent_spec_map_.erase(iter); | |||
| } else { | |||
| ++iter; | |||
| } | |||
| } | |||
| Status DistributedServable::OnAgentExit() { | |||
| SetWaitAgentsPromise(false); | |||
| return SUCCESS; | |||
| } | |||
| @@ -162,6 +160,7 @@ Status DistributedServable::StartServable(const std::string &servable_directory, | |||
| Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } | |||
| Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { | |||
| MSI_LOG_INFO << "Begin waiting ready of all agents"; | |||
| auto future = agents_promise_.get_future(); | |||
| if (wait_agents_time_in_seconds == 0) { | |||
| wait_agents_time_in_seconds = UINT32_MAX; | |||
| @@ -186,6 +185,7 @@ Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds | |||
| << "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size() | |||
| << ", rank size: " << config_.distributed_meta.rank_size; | |||
| } | |||
| MSI_LOG_INFO << "Success waiting ready of all agents"; | |||
| return SUCCESS; | |||
| } | |||
| @@ -48,7 +48,7 @@ class MS_API DistributedServable : public ServableBase { | |||
| // register and unregister agent, agent_spec_list_ | |||
| Status RegisterAgent(const WorkerAgentSpec &agent_spec); | |||
| Status UnregisterAgent(const WorkerAgentSpec &agent_spec); | |||
| Status OnAgentExit(); | |||
| // predict, use config_ and agent_spec_list_ | |||
| Status Predict(const std::vector<TensorBasePtr> &input, std::vector<TensorBasePtr> *output) override; | |||
| @@ -75,6 +75,7 @@ class MS_API DistributedServable : public ServableBase { | |||
| std::vector<TensorInfo> output_infos_; | |||
| uint64_t batch_size_ = 0; | |||
| std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT; | |||
| std::atomic_bool registered_end_flag_ = false; | |||
| std::promise<bool> agents_promise_; | |||
| Status InitConfigOnStartup(const std::string &rank_table_json_file); | |||
| @@ -25,15 +25,16 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &agent_address) { | |||
| GrpcNotifyAgent::GrpcNotifyAgent(const std::string &agent_address) { | |||
| agent_address_ = agent_address; | |||
| std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(agent_address_); | |||
| stub_ = proto::MSAgent::NewStub(channel); | |||
| } | |||
| GrpcNotfiyAgent::~GrpcNotfiyAgent() = default; | |||
| GrpcNotifyAgent::~GrpcNotifyAgent() = default; | |||
| Status GrpcNotfiyAgent::Exit() { | |||
| Status GrpcNotifyAgent::Exit() { | |||
| MSI_LOG_INFO << "Notify one agent exit begin"; | |||
| if (stub_) { | |||
| proto::DistributedExitRequest request; | |||
| request.set_address(agent_address_); | |||
| @@ -43,12 +44,18 @@ Status GrpcNotfiyAgent::Exit() { | |||
| std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); | |||
| context.set_deadline(deadline); | |||
| (void)stub_->Exit(&context, request, &reply); | |||
| auto status = stub_->Exit(&context, request, &reply); | |||
| if (status.ok()) { | |||
| MSI_LOG_INFO << "Notify one agent exit success, agent address: " << agent_address_; | |||
| } else { | |||
| MSI_LOG_INFO << "Notify one agent exit failed, agent address: " << agent_address_ | |||
| << ", error: " << status.error_code() << ", " << status.error_message(); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GrpcNotfiyAgent::DispatchAsync(const proto::DistributedPredictRequest &request, | |||
| Status GrpcNotifyAgent::DispatchAsync(const proto::DistributedPredictRequest &request, | |||
| proto::DistributedPredictReply *reply, DispatchCallback callback) { | |||
| if (!stub_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| @@ -27,10 +27,10 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API GrpcNotfiyAgent : public BaseNotifyAgent { | |||
| class MS_API GrpcNotifyAgent : public BaseNotifyAgent { | |||
| public: | |||
| explicit GrpcNotfiyAgent(const std::string &worker_address); | |||
| ~GrpcNotfiyAgent() override; | |||
| explicit GrpcNotifyAgent(const std::string &worker_address); | |||
| ~GrpcNotifyAgent() override; | |||
| Status Exit() override; | |||
| @@ -45,9 +45,10 @@ Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> & | |||
| const int32_t REGISTER_INTERVAL = 1; | |||
| auto loop = REGISTER_TIME_OUT; | |||
| while (loop-- && !ExitSignalHandle::Instance().HasStopped()) { | |||
| MSI_LOG(INFO) << "Register to " << distributed_worker_address_; | |||
| MSI_LOG(INFO) << "Register to " << distributed_worker_address_ << ", agent address: " << agent_address_; | |||
| proto::AgentRegisterRequest request; | |||
| GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request); | |||
| request.set_address(agent_address_); | |||
| proto::AgentRegisterReply reply; | |||
| grpc::ClientContext context; | |||
| std::chrono::system_clock::time_point deadline = | |||
| @@ -18,6 +18,7 @@ | |||
| #include "worker/distributed_worker/agent_process/agent_process.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| #include "common/exit_handle.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -31,26 +32,49 @@ Status WorkerAgent::Clear() { | |||
| if (notify_worker_) { | |||
| if (exit_notify_worker_) { | |||
| notify_worker_->Unregister(); | |||
| MSI_LOG_INFO << "End unregister to worker"; | |||
| } | |||
| notify_worker_ = nullptr; | |||
| } | |||
| grpc_server_.Stop(); | |||
| executor_.UnloadModel(); | |||
| session_.UnloadModel(); | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { | |||
| // todo : DistributedPredictRequest->RequestBase | |||
| // todo : DistributedPredictReply->ReplyBase | |||
| return SUCCESS; | |||
| Status status; | |||
| try { | |||
| MSI_TIME_STAMP_START(ExecuteModel) | |||
| // status = session_.ExecuteModel(request_wrap, &reply_wrap); | |||
| MSI_TIME_STAMP_END(ExecuteModel) | |||
| } catch (const std::bad_alloc &ex) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: malloc memory failed"; | |||
| } catch (const std::runtime_error &ex) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: runtime error occurred: " << ex.what(); | |||
| } catch (const std::exception &ex) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: exception occurred: " << ex.what(); | |||
| } catch (...) { | |||
| status = INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: exception occurred"; | |||
| } | |||
| if (status != SUCCESS) { | |||
| reply->Clear(); | |||
| auto error_msg = reply->mutable_error_msg(); | |||
| error_msg->set_error_code(status.StatusCode()); | |||
| error_msg->set_error_msg(status.StatusMessage()); | |||
| } | |||
| return status; | |||
| } | |||
| Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { | |||
| Status status; | |||
| config_ = config; | |||
| status = executor_.LoadModelFromFile(config); | |||
| const auto &common_meta = config.common_meta; | |||
| status = session_.LoadModelFromFile(kDeviceTypeAscendMS, config.device_id, config.model_file_name, kMindIR, | |||
| common_meta.with_batch_dim, common_meta.without_batch_dim_inputs, {}); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name | |||
| MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << common_meta.servable_name | |||
| << ", rank_id: " << config.rank_id << ", device id: " << config.device_id | |||
| << ", model file: " << config.model_file_name | |||
| << ", rank table file: " << config.rank_table_json_file_name | |||
| @@ -69,9 +93,8 @@ Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { | |||
| << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port; | |||
| return status; | |||
| } | |||
| MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name | |||
| << ", rank_id: " << config.rank_id << ", device id: " << config.device_id | |||
| << ", model file: " << config.model_file_name | |||
| MSI_LOG_INFO << "Start agent success, servable name: " << common_meta.servable_name << ", rank_id: " << config.rank_id | |||
| << ", device id: " << config.device_id << ", model file: " << config.model_file_name | |||
| << ", rank table file: " << config.rank_table_json_file_name | |||
| << ", group config file: " << config.group_file_name; | |||
| return SUCCESS; | |||
| @@ -83,14 +106,14 @@ Status WorkerAgent::StartGrpcServer() { | |||
| } | |||
| Status WorkerAgent::RegisterAgent() { | |||
| notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.agent_port, config_.agent_ip, | |||
| config_.agent_port); | |||
| notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.worker_port, | |||
| config_.agent_ip, config_.agent_port); | |||
| WorkerAgentSpec spec; | |||
| spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); | |||
| spec.rank_id = config_.rank_id; | |||
| spec.batch_size = executor_.GetBatchSize(); | |||
| spec.input_infos = executor_.GetInputInfos(); | |||
| spec.output_infos = executor_.GetOutputInfos(); | |||
| spec.batch_size = session_.GetBatchSize(); | |||
| spec.input_infos = session_.GetInputInfos(); | |||
| spec.output_infos = session_.GetOutputInfos(); | |||
| return notify_worker_->Register({spec}); | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "common/grpc_server.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| #include "worker/inference/mindspore_model_wrap.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -40,7 +41,8 @@ class MS_API WorkerAgent { | |||
| private: | |||
| AgentStartUpConfig config_; | |||
| WorkerAgentExecutor executor_; | |||
| // WorkerAgentExecutor executor_; | |||
| MindSporeModelWrap session_; | |||
| GrpcServer grpc_server_; | |||
| bool exit_notify_worker_ = true; | |||
| std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_; | |||
| @@ -38,11 +38,17 @@ using mindspore::ModelType::kMindIR; | |||
| using mindspore::ModelType::kOM; | |||
| struct TensorInfo { | |||
| size_t size; // -1: unspecified | |||
| DataType data_type; | |||
| size_t size = 0; // -1: unspecified | |||
| DataType data_type = kMSI_Unknown; | |||
| std::vector<int64_t> shape; | |||
| }; | |||
| struct TensorInfoWithBatch { | |||
| TensorInfo tensor_info; | |||
| size_t size_one_batch = 0; | |||
| std::vector<int64_t> shape_one_batch; | |||
| }; | |||
| enum DeviceType { | |||
| kDeviceTypeNotSpecified, | |||
| kDeviceTypeAscendMS, | |||
| @@ -97,9 +97,10 @@ Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, ui | |||
| } | |||
| mindspore::Status status = model->Build(); | |||
| if (!status.IsOk()) { | |||
| MSI_LOG_ERROR << "Load model from file failed, model file: " << file_name << ", device_type: '" << device_type_str | |||
| << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; | |||
| return Status(FAILED, status.ToString()); | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Load model from file failed, model file: " << file_name << ", device_type: '" << device_type_str | |||
| << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options | |||
| << ", build error detail: " << status.ToString(); | |||
| } | |||
| ApiModelInfo api_model_info; | |||
| api_model_info.model = model; | |||
| @@ -246,6 +247,9 @@ Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::Rep | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { | |||
| auto input_tensor = request[index]; | |||
| if (input_tensor == nullptr || input_tensor->data() == nullptr) { | |||
| MSI_LOG_EXCEPTION << "Input tensor data cannot be nullptr, index " << index; | |||
| } | |||
| return mindspore::MSTensor::CreateRefTensor(name, TransInferDataType2ApiTypeId(input_tensor->data_type()), | |||
| input_tensor->shape(), const_cast<uint8_t *>(input_tensor->data()), | |||
| input_tensor->data_size()); | |||
| @@ -54,19 +54,18 @@ Status WorkExecutor::CheckSevableSignature() { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count << " registered in method " | |||
| << "not equal to the count " << input_infos.size() << " defined in servable"; | |||
| } | |||
| const auto &output_infos = output_infos_; | |||
| if (output_infos.size() != common_meta.outputs_count) { | |||
| if (output_infos_.size() != common_meta.outputs_count) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "The outputs count " << common_meta.outputs_count << " registered in method " | |||
| << "not equal to the count " << output_infos.size() << " defined in servable"; | |||
| << "not equal to the count " << output_infos_.size() << " defined in servable"; | |||
| } | |||
| MSI_LOG_INFO << "Model input infos: count " << input_infos.size(); | |||
| for (auto &item : input_infos) { | |||
| MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; | |||
| } | |||
| MSI_LOG_INFO << "Model output infos: count " << output_infos.size(); | |||
| for (auto &item : output_infos) { | |||
| MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; | |||
| MSI_LOG_INFO << "Model output infos: count " << output_infos_.size(); | |||
| for (auto &item : output_infos_) { | |||
| MSI_LOG_INFO << item.tensor_info.shape << ", " << item.tensor_info.data_type << ", " << item.tensor_info.size; | |||
| } | |||
| if (common_meta.with_batch_dim) { | |||
| if (model_batch_size_ == 0) { | |||
| @@ -82,11 +81,21 @@ Status WorkExecutor::CheckSevableSignature() { | |||
| << "Servable batch size " << model_batch_size_ << " not match model input shape " << item.shape; | |||
| } | |||
| } | |||
| for (const auto &item : output_infos) { | |||
| if (item.shape.empty() || static_cast<uint32_t>(item.shape[0]) != model_batch_size_) { | |||
| for (auto &item : output_infos_) { | |||
| auto &tensor_info = item.tensor_info; | |||
| if (tensor_info.shape.empty() || static_cast<uint32_t>(tensor_info.shape[0]) != model_batch_size_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Servable batch size " << model_batch_size_ << " not match model output shape " << item.shape; | |||
| << "Servable batch size " << model_batch_size_ << " not match model output shape " << tensor_info.shape; | |||
| } | |||
| item.shape_one_batch = tensor_info.shape; | |||
| item.shape_one_batch.erase(item.shape_one_batch.begin()); | |||
| item.size_one_batch = tensor_info.size / model_batch_size_; | |||
| } | |||
| } else { | |||
| for (auto &item : output_infos_) { | |||
| auto &tensor_info = item.tensor_info; | |||
| item.shape_one_batch = tensor_info.shape; | |||
| item.size_one_batch = tensor_info.size; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -103,7 +112,12 @@ Status WorkExecutor::Init(const ServableSignature &servable_declare, const std:: | |||
| servable_declare_ = servable_declare; | |||
| servable_ = servable; | |||
| input_infos_ = servable_->GetInputInfos(); | |||
| output_infos_ = servable_->GetOutputInfos(); | |||
| auto output_infos = servable_->GetOutputInfos(); | |||
| for (auto &item : output_infos) { | |||
| TensorInfoWithBatch info; | |||
| info.tensor_info = item; | |||
| output_infos_.push_back(info); | |||
| } | |||
| if (servable_declare_.servable_meta.common_meta.with_batch_dim) { | |||
| model_batch_size_ = servable_->GetBatchSize(); | |||
| } else { | |||
| @@ -351,6 +365,10 @@ Status WorkExecutor::PrePredict(const std::vector<Instance> &inputs) { | |||
| auto data_size = tensor->data_size(); | |||
| auto dst_buffer = reinterpret_cast<uint8_t *>(tensor->mutable_data()); | |||
| if (IsNoBatchDimInput(i)) { | |||
| if (data_size != inputs[0].data[i]->data_size()) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Input " << i << " data size " << inputs[0].data[i]->data_size() | |||
| << "does not match size " << data_size << " defined in model"; | |||
| } | |||
| memcpy_s(dst_buffer, data_size, inputs[0].data[i]->data(), data_size); | |||
| continue; | |||
| } | |||
| @@ -361,7 +379,7 @@ Status WorkExecutor::PrePredict(const std::vector<Instance> &inputs) { | |||
| } | |||
| if (item_size != inputs[k].data[i]->data_size()) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) | |||
| << " Batch index " << k << " input data size " << inputs[k].data[i]->data_size() | |||
| << "Input " << i << " Batch index " << k << " input data size " << inputs[k].data[i]->data_size() | |||
| << "does not match size " << item_size << " defined in model"; | |||
| } | |||
| memcpy_s(dst_buffer + k * item_size, data_size - k * item_size, inputs[k].data[i]->data(), item_size); | |||
| @@ -382,23 +400,27 @@ Status WorkExecutor::PostPredict(const std::vector<Instance> &inputs, const std: | |||
| MSI_LOG_ERROR << "Input batch size " << input_batch_size << " invalid, model batch size " << model_batch_size; | |||
| return SYSTEM_ERROR; | |||
| } | |||
| if (predict_result.size() != output_infos_.size()) { | |||
| MSI_LOG_ERROR << "Output result count " << predict_result.size() << " not equal to output_infos_ count " | |||
| << output_infos_.size(); | |||
| return SYSTEM_ERROR; | |||
| } | |||
| std::vector<ResultInstance> results_data(input_batch_size); | |||
| for (auto &item : predict_result) { | |||
| size_t item_size = item->data_size() / model_batch_size; | |||
| if (item_size == 0) { | |||
| MSI_LOG_EXCEPTION << "Output result data size cannot be 0"; | |||
| } | |||
| auto shape = item->shape(); | |||
| if (servable_declare_.servable_meta.common_meta.with_batch_dim) { | |||
| if (shape.empty() || shape[0] != model_batch_size) { | |||
| MSI_LOG_EXCEPTION << "Output shape " << shape << " not match batch size " << model_batch_size; | |||
| } | |||
| shape.erase(shape.begin()); | |||
| for (size_t i = 0; i < predict_result.size(); i++) { | |||
| auto &item = predict_result[i]; | |||
| auto &output_info = output_infos_[i]; | |||
| if (item->data_size() != output_info.tensor_info.size) { | |||
| MSI_LOG_ERROR << "Output result " << i << " data size " << item->data_size() << " not equal to size " | |||
| << output_info.tensor_info.size << " in output_infos_ "; | |||
| return SYSTEM_ERROR; | |||
| } | |||
| auto item_size = output_info.size_one_batch; | |||
| auto shape = output_info.shape_one_batch; | |||
| auto data_type = output_info.tensor_info.data_type; | |||
| auto src_buffer = const_cast<uint8_t *>(item->data()); | |||
| for (uint32_t k = 0; k < input_batch_size; k++) { | |||
| auto tensor = std::make_shared<BufferTensorWithOwner>(item, item->data_type(), shape, src_buffer + item_size * k, | |||
| item_size, true); | |||
| auto tensor = | |||
| std::make_shared<BufferTensorWithOwner>(item, data_type, shape, src_buffer + item_size * k, item_size, true); | |||
| results_data[k].data.push_back(tensor); | |||
| } | |||
| } | |||
| @@ -58,7 +58,7 @@ class WorkExecutor { | |||
| ServableSignature servable_declare_; | |||
| std::shared_ptr<ServableBase> servable_; | |||
| std::vector<TensorInfo> input_infos_; | |||
| std::vector<TensorInfo> output_infos_; | |||
| std::vector<TensorInfoWithBatch> output_infos_; | |||
| uint32_t model_batch_size_ = 0; | |||
| uint64_t worker_id_ = 0; | |||
| bool init_flag_ = false; | |||
| @@ -16,6 +16,8 @@ | |||
| import os | |||
| import time | |||
| import sys | |||
| import traceback | |||
| from multiprocessing import Process, Pipe | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| @@ -46,7 +48,7 @@ def _get_local_ip(rank_list, port): | |||
| def _update_model_files_path(model_files, group_config_files): | |||
| """Check and return model files or group config files""" | |||
| script_dir = os.path.dirname(os.path.realpath(__file__)) | |||
| script_dir = os.path.dirname(os.path.realpath(sys.argv[0])) | |||
| logger.info(f"input model files: {model_files}") | |||
| logger.info(f"input group config files: {group_config_files}") | |||
| model_files_temp = [] | |||
| @@ -72,8 +74,8 @@ def _make_json_table_file(distributed_config): | |||
| """Make rank table json file""" | |||
| rank_size = len(distributed_config.rank_list) | |||
| runtime_dir = os.path.abspath(".") | |||
| time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) | |||
| rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") | |||
| time_stamp = str(time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))) | |||
| rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}p.json") | |||
| with open(rank_table_file_name, "w") as fp: | |||
| fp.write(distributed_config.rank_table_content) | |||
| return rank_table_file_name | |||
| @@ -125,18 +127,20 @@ def _agent_process(send_pipe, recv_pipe, index, start_config): | |||
| success_msg = _recv_parent(index, recv_pipe) | |||
| if not success_msg: | |||
| worker_agent.stop() | |||
| send_pipe.close() | |||
| recv_pipe.close() | |||
| # pylint: disable=broad-except | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| logger.error(f"Child {index}: Catch exception and notify exit of others") | |||
| send_pipe.send((index, e)) | |||
| _recv_parent(index, recv_pipe) # receive exit message from parent process | |||
| worker_agent.stop() | |||
| raise | |||
| send_pipe.close() | |||
| recv_pipe.close() | |||
| def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list): | |||
| """Listening child process""" | |||
| def send_pipe_msg(send_pipe, msg): | |||
| try: | |||
| send_pipe.send(msg) | |||
| @@ -217,8 +221,8 @@ def startup_worker_agents(worker_ip, worker_port, model_files, group_config_file | |||
| check_type.check_str("worker_ip", worker_ip) | |||
| check_type.check_ip_port("worker_port", worker_port) | |||
| check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) | |||
| model_files = check_type.check_and_as_int_tuple_list("model_files", model_files) | |||
| group_config_files = check_type.check_and_as_int_tuple_list("group_config_files", group_config_files) | |||
| model_files = check_type.check_and_as_str_tuple_list("model_files", model_files) | |||
| group_config_files = check_type.check_and_as_str_tuple_list("group_config_files", group_config_files) | |||
| distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port) | |||
| # get machine ip | |||
| @@ -23,7 +23,7 @@ from mindspore_serving.worker._worker import stop_on_except, _load_servable_conf | |||
| @stop_on_except | |||
| def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, | |||
| worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100, | |||
| wait_agents_time_in_seconds=300): | |||
| wait_agents_time_in_seconds=0): | |||
| r""" | |||
| Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master | |||
| through gRPC (master_ip, master_port). | |||
| @@ -48,7 +48,8 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso | |||
| master_port (int): The master port the worker linked to. | |||
| worker_ip (str): The worker ip the master and agents linked to. | |||
| worker_port (int): The worker port the master and agents linked to. | |||
| wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents. | |||
| wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents, | |||
| 0 means unlimited time, default 0. | |||
| Examples: | |||
| >>> import os | |||
| @@ -81,7 +82,7 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso | |||
| @stop_on_except | |||
| def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, | |||
| worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300): | |||
| worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=0): | |||
| r""" | |||
| Start up the servable named 'servable_name' defined in 'servable_directory', and the worker will run in | |||
| the process of the master. | |||
| @@ -100,7 +101,8 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank | |||
| rank_table_json_file (str): The rank table json file name. | |||
| worker_ip (str): The worker ip the agents linked to. | |||
| worker_port (int): The worker port the agents linked to. | |||
| wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents. | |||
| wait_agents_time_in_seconds(int): The maximum time in seconds the worker waiting ready of all agents, | |||
| 0 means unlimited time, default 0. | |||
| Examples: | |||
| >>> import os | |||
| @@ -20,7 +20,7 @@ from mindspore_serving.worker.common import get_servable_dir | |||
| from mindspore_serving import log as logger | |||
| def declare_distributed_servable(rank_size, stage_size, with_batch_dim, without_batch_dim_inputs): | |||
| def declare_distributed_servable(rank_size, stage_size, with_batch_dim=True, without_batch_dim_inputs=None): | |||
| """declare distributed servable in servable_config.py""" | |||
| check_type.check_bool('with_batch_dim', with_batch_dim) | |||
| @@ -47,11 +47,11 @@ def start_wait_and_clear(): | |||
| """Waiting for Ctrl+C, and clear up environment""" | |||
| def thread_func(): | |||
| logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------") | |||
| print("Serving worker: wait for Ctrl+C to exit ------------------------------------") | |||
| logger.info("Serving worker Agent: wait for Ctrl+C to exit ------------------------------------") | |||
| print("Serving worker Agent: wait for Ctrl+C to exit ------------------------------------") | |||
| WorkerAgent_.wait_and_clear() | |||
| logger.info("Serving worker: exited ------------------------------------") | |||
| print("Serving worker: exited ------------------------------------") | |||
| logger.info("Serving worker Agent: exited ------------------------------------") | |||
| print("Serving worker Agent: exited ------------------------------------") | |||
| global _wait_and_clear_thread | |||
| if not _wait_and_clear_thread: | |||