From: @xu-yfei Reviewed-by: @zhoufeng54,@zhangyinxia Signed-off-by: @zhangyinxiatags/v1.2.0
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -39,7 +39,7 @@ namespace serving { | |||
| using PredictOnFinish = std::function<void()>; | |||
| using DispatchCallback = std::function<void(Status status)>; | |||
| using AsyncPredictCallback = std::function<void(Status status)>; | |||
| template <typename Request, typename Reply, typename MSStub> | |||
| class MSServiceClient { | |||
| @@ -80,7 +80,7 @@ class MSServiceClient { | |||
| } | |||
| } | |||
| void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) { | |||
| void PredictAsync(const Request &request, Reply *reply, MSStub *stub, AsyncPredictCallback callback) { | |||
| AsyncClientCall *call = new AsyncClientCall; | |||
| call->reply = reply; | |||
| call->callback = std::move(callback); | |||
| @@ -95,7 +95,7 @@ class MSServiceClient { | |||
| grpc::ClientContext context; | |||
| grpc::Status status; | |||
| Reply *reply; | |||
| DispatchCallback callback; | |||
| AsyncPredictCallback callback; | |||
| std::shared_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader; | |||
| }; | |||
| @@ -256,8 +256,17 @@ Status GrpcTensorHelper::CreateInstanceFromRequest(const proto::PredictRequest & | |||
| return SUCCESS; | |||
| } | |||
| Status GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request, | |||
| const vector<InstancePtr> &instances, proto::PredictReply *reply) { | |||
| void GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request, | |||
| const vector<InstancePtr> &instances, proto::PredictReply *reply) { | |||
| auto status = CreateReplyFromInstancesInner(request, instances, reply); | |||
| if (status != SUCCESS) { | |||
| CreateReplyFromErrorMsg(status, reply); | |||
| } | |||
| } | |||
| Status GrpcTensorHelper::CreateReplyFromInstancesInner(const proto::PredictRequest &request, | |||
| const std::vector<InstancePtr> &instances, | |||
| proto::PredictReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| *reply->mutable_servable_spec() = request.servable_spec(); | |||
| if (instances.empty()) { | |||
| @@ -422,6 +431,23 @@ Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { | |||
| return SUCCESS; | |||
| } | |||
| void GrpcTensorHelper::CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| if (error_msg == SUCCESS) { | |||
| return; | |||
| } | |||
| reply->clear_error_msg(); | |||
| reply->clear_instances(); | |||
| auto proto_error_msg = reply->add_error_msg(); | |||
| proto_error_msg->set_error_code(FAILED); | |||
| std::string error_msg_str = error_msg.StatusMessage(); | |||
| if (error_msg_str.empty()) { | |||
| proto_error_msg->set_error_msg("Predict failed"); | |||
| } else { | |||
| proto_error_msg->set_error_msg(error_msg_str); | |||
| } | |||
| } | |||
| serving::LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type) { | |||
| const std::map<proto::DataType, std::string> type_name_map{ | |||
| {proto::MS_UNKNOWN, "proto::MS_UNKNOWN"}, {proto::MS_BOOL, "proto::kMSI_Bool"}, | |||
| @@ -67,8 +67,9 @@ class MS_API GrpcTensorHelper { | |||
| static void GetWorkerSpec(const proto::RemoveWorkerRequest &request, WorkerSpec *worker_spec); | |||
| static Status CreateInstanceFromRequest(const proto::PredictRequest &request, RequestSpec *request_spec, | |||
| std::vector<InstanceData> *results); | |||
| static Status CreateReplyFromInstances(const proto::PredictRequest &request, | |||
| const std::vector<InstancePtr> &instances, proto::PredictReply *reply); | |||
| static void CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector<InstancePtr> &instances, | |||
| proto::PredictReply *reply); | |||
| static void CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply); | |||
| static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs); | |||
| static void CopyFromWorkerAgentSpec(const std::vector<WorkerAgentSpec> &worker_specs, | |||
| proto::AgentRegisterRequest *request); | |||
| @@ -78,6 +79,9 @@ class MS_API GrpcTensorHelper { | |||
| const std::vector<std::string> &input_names, | |||
| std::vector<InstanceData> *results); | |||
| static Status CheckRequestTensor(const proto::Tensor &tensor); | |||
| static Status CreateReplyFromInstancesInner(const proto::PredictRequest &request, | |||
| const std::vector<InstancePtr> &instances, proto::PredictReply *reply); | |||
| }; | |||
| extern MS_API LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type); | |||
| @@ -55,25 +55,47 @@ DispatcherWorkerContext Dispatcher::GetWorkSession(const RequestSpec &request_sp | |||
| return context; | |||
| } | |||
| Status Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) { | |||
| void Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| auto promise = std::make_shared<std::pair<std::promise<void>, Status>>(std::make_pair(std::promise<void>(), FAILED)); | |||
| auto future = promise->first.get_future(); | |||
| DispatchCallback callback = [promise](Status status) { | |||
| promise->second = status; | |||
| promise->first.set_value(); | |||
| }; | |||
| auto status = DispatchAsync(request, reply, callback); | |||
| auto promise = std::make_shared<std::promise<void>>(); | |||
| auto future = promise->get_future(); | |||
| PredictOnFinish on_finish = [promise]() { promise->set_value(); }; | |||
| DispatchAsync(request, reply, on_finish); | |||
| future.get(); // wait callback finish | |||
| } | |||
| void Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| Status status; | |||
| (*reply->mutable_servable_spec()) = request.servable_spec(); | |||
| try { | |||
| MSI_TIME_STAMP_START(Predict) | |||
| status = DispatchAsyncInner(request, reply, on_finish); | |||
| MSI_TIME_STAMP_END(Predict) | |||
| } catch (const std::bad_alloc &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; | |||
| std::cout << "Serving Error: malloc memory failed" << std::endl; | |||
| } catch (const std::runtime_error &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); | |||
| std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl; | |||
| } catch (const std::exception &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); | |||
| std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl; | |||
| } catch (...) { | |||
| MSI_LOG(ERROR) << "Serving Error: exception occurred"; | |||
| std::cout << "Serving Error: exception occurred"; | |||
| } | |||
| MSI_LOG(INFO) << "Finish call service Eval"; | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "DispatchAsync failed"; | |||
| return status; | |||
| GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); | |||
| on_finish(); | |||
| } | |||
| future.get(); // wait callback finish | |||
| return promise->second; | |||
| } | |||
| Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) { | |||
| Status Dispatcher::DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish) { | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| std::shared_lock<std::shared_mutex> lock(servable_shared_lock_); | |||
| RequestSpec request_spec; | |||
| @@ -88,7 +110,7 @@ Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::Pr | |||
| if (!find_method) { | |||
| return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available"; | |||
| } | |||
| return worker.notify_worker_->DispatchAsync(request, reply, std::move(callback)); | |||
| return worker.notify_worker_->DispatchAsync(request, reply, std::move(on_finish)); | |||
| } | |||
| Status Dispatcher::RegisterServableCommon(const std::vector<WorkerSpec> &worker_specs, CreateNotifyWorkerFunc func) { | |||
| @@ -216,7 +238,7 @@ Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto | |||
| std::vector<WorkerSpec> worker_specs; | |||
| GrpcTensorHelper::GetWorkerSpec(request, &worker_specs); | |||
| auto create_notify_worker = [](const WorkerSpec &worker_spec) { | |||
| std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotfiyWorker>(worker_spec.worker_address); | |||
| std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address); | |||
| return notify_worker; | |||
| }; | |||
| return RegisterServableCommon(worker_specs, create_notify_worker); | |||
| @@ -232,7 +254,7 @@ Status Dispatcher::AddServable(const proto::AddWorkerRequest &request, proto::Ad | |||
| GrpcTensorHelper::GetWorkerSpec(request, &worker_spec); | |||
| auto create_notify_worker = [](const WorkerSpec &worker_spec) { | |||
| std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotfiyWorker>(worker_spec.worker_address); | |||
| std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address); | |||
| return notify_worker; | |||
| }; | |||
| return AddServableCommon(worker_spec, create_notify_worker); | |||
| @@ -40,8 +40,8 @@ class MS_API Dispatcher { | |||
| public: | |||
| Dispatcher(); | |||
| ~Dispatcher(); | |||
| Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply); | |||
| Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback); | |||
| void Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply); | |||
| void DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish); | |||
| Status RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply *reply); | |||
| Status UnregisterServable(const proto::ExitRequest &request, proto::ExitReply *reply); | |||
| @@ -71,6 +71,9 @@ class MS_API Dispatcher { | |||
| Status UnregisterServableCommon(const std::string &worker_address); | |||
| Status AddServableCommon(const WorkerSpec &worker_spec, CreateNotifyWorkerFunc func); | |||
| Status RemoveServableCommon(const WorkerSpec &worker_spec); | |||
| Status DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish); | |||
| }; | |||
| } // namespace mindspore::serving | |||
| @@ -36,53 +36,9 @@ std::string GetProtorWorkerSpecRepr(const proto::WorkerSpec &worker_spec) { | |||
| } | |||
| } // namespace | |||
| Status MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, | |||
| DispatchCallback callback) { | |||
| MSI_EXCEPTION_IF_NULL(request); | |||
| MSI_EXCEPTION_IF_NULL(reply); | |||
| Status status(FAILED); | |||
| auto on_status = [request, reply](Status status) { | |||
| if (status != SUCCESS) { | |||
| (*reply->mutable_servable_spec()) = request->servable_spec(); | |||
| reply->clear_error_msg(); | |||
| auto proto_error_msg = reply->add_error_msg(); | |||
| proto_error_msg->set_error_code(FAILED); | |||
| std::string error_msg = status.StatusMessage(); | |||
| if (error_msg.empty()) { | |||
| proto_error_msg->set_error_msg("Predict failed"); | |||
| } else { | |||
| proto_error_msg->set_error_msg(error_msg); | |||
| } | |||
| } | |||
| }; | |||
| DispatchCallback callback_with_status_handle = [callback, on_status](Status status) { | |||
| on_status(status); | |||
| callback(status); | |||
| }; | |||
| try { | |||
| MSI_TIME_STAMP_START(Predict) | |||
| status = dispatcher_->DispatchAsync(*request, reply, callback_with_status_handle); | |||
| MSI_TIME_STAMP_END(Predict) | |||
| } catch (const std::bad_alloc &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; | |||
| std::cout << "Serving Error: malloc memory failed" << std::endl; | |||
| } catch (const std::runtime_error &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); | |||
| std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl; | |||
| } catch (const std::exception &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); | |||
| std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl; | |||
| } catch (...) { | |||
| MSI_LOG(ERROR) << "Serving Error: exception occurred"; | |||
| std::cout << "Serving Error: exception occurred"; | |||
| } | |||
| MSI_LOG(INFO) << "Finish call service Eval"; | |||
| if (status != SUCCESS) { | |||
| on_status(status); | |||
| return status; | |||
| } | |||
| return SUCCESS; | |||
| void MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish) { | |||
| dispatcher_->DispatchAsync(*request, reply, on_finish); | |||
| } | |||
| grpc::Status MSMasterImpl::Register(grpc::ServerContext *context, const proto::RegisterRequest *request, | |||
| @@ -41,7 +41,7 @@ class MSServiceImpl { | |||
| explicit MSServiceImpl(std::shared_ptr<Dispatcher> dispatcher) : dispatcher_(dispatcher) {} | |||
| ~MSServiceImpl() = default; | |||
| Status PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, DispatchCallback callback); | |||
| void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, PredictOnFinish on_finish); | |||
| private: | |||
| std::shared_ptr<Dispatcher> dispatcher_; | |||
| @@ -94,11 +94,8 @@ class MasterPredictContext : public MasterServiceContext { | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); }; | |||
| Status status = service_impl_->PredictAsync(&request_, &response_, callback); | |||
| if (!status.IsSuccess()) { | |||
| responder_.Finish(response_, grpc::Status::OK, this); | |||
| } | |||
| PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); }; | |||
| service_impl_->PredictAsync(&request_, &response_, on_finish); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| @@ -33,7 +33,7 @@ class MS_API BaseNotifyWorker { | |||
| virtual ~BaseNotifyWorker() = default; | |||
| virtual Status Exit() = 0; | |||
| virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) = 0; | |||
| PredictOnFinish on_finish) = 0; | |||
| }; | |||
| } // namespace serving | |||
| @@ -20,19 +20,20 @@ | |||
| #include <thread> | |||
| #include "common/exit_handle.h" | |||
| #include "common/grpc_server.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| GrpcNotfiyWorker::GrpcNotfiyWorker(const std::string &worker_address) { | |||
| GrpcNotifyWorker::GrpcNotifyWorker(const std::string &worker_address) { | |||
| worker_address_ = worker_address; | |||
| std::shared_ptr<grpc::Channel> channel = GrpcServer::CreateChannel(worker_address); | |||
| stub_ = proto::MSWorker::NewStub(channel); | |||
| } | |||
| GrpcNotfiyWorker::~GrpcNotfiyWorker() = default; | |||
| GrpcNotifyWorker::~GrpcNotifyWorker() = default; | |||
| Status GrpcNotfiyWorker::Exit() { | |||
| Status GrpcNotifyWorker::Exit() { | |||
| if (stub_) { | |||
| proto::ExitRequest request; | |||
| request.set_address(worker_address_); | |||
| @@ -47,8 +48,8 @@ Status GrpcNotfiyWorker::Exit() { | |||
| return SUCCESS; | |||
| } | |||
| Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) { | |||
| Status GrpcNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish) { | |||
| if (!stub_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Predict failed, worker gRPC has not been inited or has already exited, worker address " | |||
| @@ -58,6 +59,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro | |||
| client_ = std::make_unique<MSPredictClient>(); | |||
| client_->Start(); | |||
| } | |||
| AsyncPredictCallback callback = [reply, on_finish](Status status) { | |||
| GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); | |||
| on_finish(); | |||
| }; | |||
| client_->PredictAsync(request, reply, stub_.get(), callback); | |||
| return SUCCESS; | |||
| } | |||
| @@ -27,15 +27,15 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API GrpcNotfiyWorker : public BaseNotifyWorker { | |||
| class MS_API GrpcNotifyWorker : public BaseNotifyWorker { | |||
| public: | |||
| explicit GrpcNotfiyWorker(const std::string &worker_address); | |||
| ~GrpcNotfiyWorker() override; | |||
| explicit GrpcNotifyWorker(const std::string &worker_address); | |||
| ~GrpcNotifyWorker() override; | |||
| Status Exit() override; | |||
| Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) override; | |||
| PredictOnFinish on_finish) override; | |||
| private: | |||
| std::string worker_address_; | |||
| @@ -26,8 +26,8 @@ Status LocalNotifyWorker::Exit() { | |||
| } | |||
| Status LocalNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) { | |||
| return Worker::GetInstance().RunAsync(request, reply, callback); | |||
| PredictOnFinish on_finish) { | |||
| return Worker::GetInstance().RunAsync(request, reply, on_finish); | |||
| } | |||
| } // namespace serving | |||
| @@ -30,7 +30,7 @@ class MS_API LocalNotifyWorker : public BaseNotifyWorker { | |||
| Status Exit() override; | |||
| Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, | |||
| DispatchCallback callback) override; | |||
| PredictOnFinish on_finish) override; | |||
| }; | |||
| } // namespace serving | |||
| @@ -693,14 +693,8 @@ Status RestfulService::RunRestful(const std::shared_ptr<RestfulRequest> &restful | |||
| } | |||
| MSI_TIME_STAMP_START(Predict) | |||
| status = dispatcher_->Dispatch(request, &reply); | |||
| dispatcher_->Dispatch(request, &reply); | |||
| MSI_TIME_STAMP_END(Predict) | |||
| if (status != SUCCESS) { | |||
| std::string error_msg = status.StatusMessage(); | |||
| std::string msg = "Predict failed, " + error_msg; | |||
| status = msg; | |||
| return status; | |||
| } | |||
| MSI_TIME_STAMP_START(CreateReplyJson) | |||
| status = ParseReply(reply, out_json); | |||
| @@ -1037,11 +1031,6 @@ Status RestfulService::CheckReply(const ProtoTensor &pb_tensor) { | |||
| return status; | |||
| } | |||
| void RestfulService::ParseErrorMsg(const proto::ErrorMsg &error, json *const js) { | |||
| std::string str = error.error_msg(); | |||
| *js = str; | |||
| } | |||
| // 5.Parse reply | |||
| Status RestfulService::ParseReply(const PredictReply &reply, json *const out_json) { | |||
| Status status(SUCCESS); | |||
| @@ -1059,56 +1048,42 @@ Status RestfulService::ParseReply(const PredictReply &reply, json *const out_jso | |||
| Status RestfulService::ParseInstancesReply(const PredictReply &reply, json *const out_json) { | |||
| Status status(SUCCESS); | |||
| auto error_size = reply.error_msg_size(); | |||
| if (error_size != 0 && error_size != 1 && error_size != instances_nums_) { | |||
| auto reply_size = reply.instances().size(); | |||
| if (error_size == 1 && reply_size == 0) { | |||
| (*out_json)[kErrorMsg] = reply.error_msg()[0].error_msg(); | |||
| return SUCCESS; | |||
| } | |||
| if (error_size != 0 && error_size != instances_nums_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "reply error size:" << error_size << " is not 0,1 or instances size"; | |||
| } | |||
| if (reply_size != instances_nums_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "reply size:" << reply_size << " is not matched request size:" << instances_nums_; | |||
| } | |||
| (*out_json)[kInstancesReply] = json(); | |||
| json &instances_json = (*out_json)[kInstancesReply]; | |||
| int32_t reply_num = reply.instances().size(); | |||
| if (reply_num == 0) { | |||
| reply_num = error_size; | |||
| } | |||
| if (error_size == 0 && reply_num != instances_nums_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "reply size:" << reply_num << " is not matched request size:" << instances_nums_; | |||
| } | |||
| for (int32_t i = 0; i < reply_num; i++) { | |||
| bool success_flag = true; | |||
| if (i < error_size) { | |||
| auto &cur_error = reply.error_msg().at(i); | |||
| success_flag = (cur_error.error_code() == 0); | |||
| for (int32_t i = 0; i < instances_nums_; i++) { | |||
| instances_json.push_back(json()); | |||
| auto &instance = instances_json.back(); | |||
| if (error_size != 0 && reply.error_msg()[i].error_code() != 0) { | |||
| instance[kErrorMsg] = reply.error_msg(i).error_msg(); | |||
| continue; | |||
| } | |||
| auto &cur_instance = reply.instances(i); | |||
| auto &items = cur_instance.items(); | |||
| if (items.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty"; | |||
| } | |||
| if (success_flag) { | |||
| if (i >= reply.instances_size()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "index:" << i << " is more than reply instances size:" << reply.instances_size(); | |||
| } | |||
| auto &cur_instance = reply.instances(i); | |||
| auto &items = cur_instance.items(); | |||
| if (items.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty"; | |||
| } | |||
| instances_json.push_back(json()); | |||
| auto &instance = instances_json.back(); | |||
| for (auto &item : items) { | |||
| instance[item.first] = json(); | |||
| auto &value_json = instance[item.first]; | |||
| status = ParseReplyDetail(item.second, &value_json); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| for (auto &item : items) { | |||
| instance[item.first] = json(); | |||
| auto &value_json = instance[item.first]; | |||
| status = ParseReplyDetail(item.second, &value_json); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| } else { | |||
| instances_json.push_back(json()); | |||
| auto &obj = instances_json.back(); | |||
| obj[kErrorMsg] = json(); | |||
| auto &js = obj[kErrorMsg]; | |||
| ParseErrorMsg(reply.error_msg(i), &js); | |||
| } | |||
| } | |||
| return status; | |||
| @@ -98,7 +98,6 @@ class RestfulService { | |||
| Status ParseScalarData(const ProtoTensor &pb_tensor, bool is_bytes, size_t index, json *const js); | |||
| template <typename T> | |||
| bool IsString(); | |||
| void ParseErrorMsg(const proto::ErrorMsg &error_msg, json *const js); | |||
| RequestType request_type_{kInvalidType}; | |||
| InstancesType instances_type_{kInvalidWay}; | |||
| @@ -41,13 +41,13 @@ class MSDistributedImpl final : public MSWorkerImpl { | |||
| : MSWorkerImpl(server_address), servable_(servable) {} | |||
| ~MSDistributedImpl() = default; | |||
| grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, | |||
| proto::AgentRegisterReply *reply) override; | |||
| proto::AgentRegisterReply *reply); | |||
| grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, | |||
| proto::AgentExitReply *reply) override; | |||
| proto::AgentExitReply *reply); | |||
| grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, | |||
| proto::AgentFailedReply *reply) override; | |||
| proto::AgentFailedReply *reply); | |||
| grpc::Status AgentConfigAcquire(grpc::ServerContext *context, const proto::AgentConfigAcquireRequest *request, | |||
| proto::AgentConfigAcquireReply *reply) override; | |||
| proto::AgentConfigAcquireReply *reply); | |||
| private: | |||
| std::shared_ptr<DistributedServable> servable_; | |||
| @@ -76,7 +76,7 @@ Status DistributedServable::PredictInner(const std::vector<TensorBasePtr> &input | |||
| auto msg_list = std::make_shared<std::vector<DistributedPredictMsg>>(rank_size); | |||
| for (size_t i = 0; i < rank_size; ++i) { | |||
| DispatchCallback callback = [msg_list, i](const Status &status) { | |||
| AsyncPredictCallback callback = [msg_list, i](const Status &status) { | |||
| msg_list->at(i).status = status; | |||
| msg_list->at(i).promise.set_value(); | |||
| }; | |||
| @@ -33,7 +33,7 @@ class MS_API BaseNotifyAgent { | |||
| virtual ~BaseNotifyAgent() = default; | |||
| virtual Status Exit() = 0; | |||
| virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, | |||
| DispatchCallback callback) = 0; | |||
| AsyncPredictCallback callback) = 0; | |||
| }; | |||
| } // namespace serving | |||
| @@ -55,7 +55,7 @@ Status GrpcNotifyAgent::Exit() { | |||
| } | |||
| Status GrpcNotifyAgent::DispatchAsync(const proto::DistributedPredictRequest &request, | |||
| proto::DistributedPredictReply *reply, DispatchCallback callback) { | |||
| proto::DistributedPredictReply *reply, AsyncPredictCallback callback) { | |||
| if (!stub_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_; | |||
| @@ -35,7 +35,7 @@ class MS_API GrpcNotifyAgent : public BaseNotifyAgent { | |||
| Status Exit() override; | |||
| Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, | |||
| DispatchCallback callback) override; | |||
| AsyncPredictCallback callback) override; | |||
| private: | |||
| std::string agent_address_; | |||
| @@ -16,6 +16,7 @@ | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/worker.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -26,13 +27,13 @@ grpc::Status MSWorkerImpl::Exit(grpc::ServerContext *context, const proto::ExitR | |||
| return grpc::Status::OK; | |||
| } | |||
| grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, | |||
| proto::PredictReply *reply, DispatchCallback callback) { | |||
| void MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, | |||
| proto::PredictReply *reply, PredictOnFinish on_finish) { | |||
| Status status(FAILED); | |||
| MSI_LOG(INFO) << "Begin call service Eval"; | |||
| try { | |||
| MSI_TIME_STAMP_START(Predict) | |||
| status = Worker::GetInstance().RunAsync(*request, reply, callback); | |||
| status = Worker::GetInstance().RunAsync(*request, reply, on_finish); | |||
| MSI_TIME_STAMP_END(Predict) | |||
| } catch (const std::bad_alloc &ex) { | |||
| MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; | |||
| @@ -49,19 +50,12 @@ grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const prot | |||
| } | |||
| MSI_LOG(INFO) << "Finish call service Eval"; | |||
| if (status == INVALID_INPUTS) { | |||
| auto proto_error_msg = reply->add_error_msg(); | |||
| proto_error_msg->set_error_code(status.StatusCode()); | |||
| proto_error_msg->set_error_msg(status.StatusMessage()); | |||
| return grpc::Status::OK; | |||
| } else if (status != SUCCESS) { | |||
| auto proto_error_msg = reply->add_error_msg(); | |||
| proto_error_msg->set_error_code(FAILED); | |||
| proto_error_msg->set_error_msg("Predict failed"); | |||
| return grpc::Status::OK; | |||
| if (status != SUCCESS) { | |||
| GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); | |||
| on_finish(); | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| grpc::Status MSWorkerImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request, | |||
| proto::PingReply *reply) { | |||
| MSI_EXCEPTION_IF_NULL(request); | |||
| @@ -35,7 +35,7 @@ namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSWorkerImpl : public proto::MSWorker::Service { | |||
| class MSWorkerImpl { | |||
| public: | |||
| explicit MSWorkerImpl(const std::string server_address) { | |||
| if (!watcher_) { | |||
| @@ -43,11 +43,11 @@ class MSWorkerImpl : public proto::MSWorker::Service { | |||
| } | |||
| } | |||
| grpc::Status PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, | |||
| proto::PredictReply *reply, DispatchCallback callback); | |||
| grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; | |||
| grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; | |||
| grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; | |||
| void PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply, | |||
| PredictOnFinish on_finish); | |||
| grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply); | |||
| grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply); | |||
| grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply); | |||
| std::shared_ptr<Watcher<proto::MSAgent, proto::MSMaster>> watcher_; | |||
| }; | |||
| @@ -100,11 +100,8 @@ class WorkerPredictContext : public WorkerServiceContext { | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); }; | |||
| grpc::Status status = service_impl_->PredictAsync(&ctx_, &request_, &response_, callback); | |||
| if (!status.ok()) { | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); }; | |||
| service_impl_->PredictAsync(&ctx_, &request_, &response_, on_finish); | |||
| } | |||
| private: | |||
| @@ -141,22 +141,22 @@ std::shared_ptr<Context> MindSporeModelWrap::TransformModelContext(const std::ma | |||
| MSI_LOG_ERROR << "Set model context output type failed, unknown data type " << val; | |||
| } | |||
| }; | |||
| std::map<std::string, ContextStrFun> option_map = { | |||
| {"acl_option.insert_op_config_file_path", mindspore::ModelContext::SetInsertOpConfigPath}, | |||
| {"acl_option.input_format", mindspore::ModelContext::SetInputFormat}, | |||
| {"acl_option.input_shape", mindspore::ModelContext::SetInputShape}, | |||
| {"acl_option.output_type", set_output_type}, | |||
| {"acl_option.precision_mode", mindspore::ModelContext::SetPrecisionMode}, | |||
| {"acl_option.op_select_impl_mode", mindspore::ModelContext::SetOpSelectImplMode}, | |||
| }; | |||
| auto context = std::make_shared<mindspore::ModelContext>(); | |||
| for (auto &item : options) { | |||
| const auto &key = item.first; | |||
| const auto &value = item.second; | |||
| auto it = option_map.find(key); | |||
| if (it != option_map.end()) { | |||
| MSI_LOG_INFO << "Set context options, key: " << key << ", value: " << value; | |||
| it->second(context, value); | |||
| if (key == "acl_option.insert_op_config_file_path") { | |||
| mindspore::ModelContext::SetInsertOpConfigPath(context, value); | |||
| } else if (key == "acl_option.input_format") { | |||
| mindspore::ModelContext::SetInputFormat(context, value); | |||
| } else if (key == "acl_option.input_shape") { | |||
| mindspore::ModelContext::SetInputShape(context, value); | |||
| } else if (key == "acl_option.output_type") { | |||
| set_output_type(context, value); | |||
| } else if (key == "acl_option.precision_mode") { | |||
| mindspore::ModelContext::SetPrecisionMode(context, value); | |||
| } else if (key == "acl_option.op_select_impl_mode") { | |||
| mindspore::ModelContext::SetOpSelectImplMode(context, value); | |||
| } | |||
| } | |||
| return context; | |||
| @@ -60,7 +60,7 @@ Status Worker::RemoveWorker(const ServableWorkerContext &work) { | |||
| return notify_master_->RemoveWorker(work.worker_spec); | |||
| } | |||
| Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback) { | |||
| Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish) { | |||
| std::shared_lock<std::shared_mutex> lock(worker_shared_lock_); | |||
| if (!servable_started_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "RunAsync worker for inference failed, worker has not been started"; | |||
| @@ -81,16 +81,9 @@ Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictRepl | |||
| if (worker.worker_service == nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find servable match " << request_spec.Repr(); | |||
| } | |||
| WorkCallBack on_process_done = [request, reply, callback](const std::vector<InstancePtr> &instances) { | |||
| auto status = GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "transfer result to reply failed"; | |||
| reply->clear_error_msg(); | |||
| auto proto_error = reply->add_error_msg(); | |||
| proto_error->set_error_code(status.StatusCode()); | |||
| proto_error->set_error_msg(status.StatusMessage()); | |||
| } | |||
| callback(SUCCESS); | |||
| WorkCallBack on_process_done = [request, reply, on_finish](const std::vector<InstancePtr> &instances) { | |||
| GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply); | |||
| on_finish(); | |||
| }; | |||
| return worker.worker_service->Work(request_spec, instances_data, on_process_done); | |||
| } | |||
| @@ -53,7 +53,7 @@ class MS_API Worker { | |||
| static Worker &GetInstance(); | |||
| void Clear(); | |||
| Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback); | |||
| Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish); | |||
| Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | |||
| Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip, | |||
| @@ -376,11 +376,8 @@ class TestMasterWorkerClient : public TestMasterWorker { | |||
| grpc::ServerContext context; | |||
| auto promise = std::make_shared<std::promise<void>>(); | |||
| auto future = promise->get_future(); | |||
| DispatchCallback callback = [promise](Status status) { promise->set_value(); }; | |||
| auto status = impl.PredictAsync(&request, reply, callback); | |||
| if (!status.IsSuccess()) { | |||
| return grpc::Status::OK; | |||
| } | |||
| PredictOnFinish callback = [promise]() { promise->set_value(); }; | |||
| impl.PredictAsync(&request, reply, callback); | |||
| future.get(); | |||
| return grpc::Status::OK; | |||
| } | |||
| @@ -173,7 +173,7 @@ def test_grpc_start_restful_server_twice_failed(): | |||
| @serving_test | |||
| def test_grpc_alone_repeat_master_and_woker_port_failed(): | |||
| def test_grpc_alone_repeat_master_and_worker_port_failed(): | |||
| base = ServingTestBase() | |||
| base.init_servable(1, "add_servable_config.py") | |||
| master.start_master_server(master_port=7600) | |||
| @@ -731,3 +731,370 @@ def add_common(x1, x2): | |||
| assert result[0]["y"] == 0 | |||
| assert "Preprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) | |||
| assert result[0]["y"] == 0 | |||
| @serving_test | |||
| def test_servable_worker_with_master_preprocess_runtime_error(): | |||
| # Failure raised inside the preprocess stage, worker started via start_servable_in_master. | |||
| # The servable config below raises RuntimeError for the first instance (global index 0) | |||
| # and yields the running index for the others; add_common returns that preprocess output. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| index = 0 | |||
| def preprocess(instances): | |||
| count = len(instances) | |||
| global index | |||
| for i in range(count): | |||
| ret = index | |||
| index += 1 | |||
| if ret == 0: | |||
| raise RuntimeError("runtime error") | |||
| yield ret | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| x3 = register.call_preprocess_pipeline(preprocess, x1) | |||
| y = register.call_servable(x1, x2) | |||
| return x3 | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| worker.start_servable_in_master(base.servable_dir, base.servable_name) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: send three instances of 2x2 float32 inputs, each scaled by (i + 1). | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Only the first instance carries an error; the other two return the preprocess index. | |||
| assert "Preprocess Failed" in result[0]["error"] | |||
| assert result[1]["y"] == 1 | |||
| assert result[2]["y"] == 2 | |||
| @serving_test | |||
| def test_servable_worker_alone_preprocess_runtime_error(): | |||
| # Failure raised inside the preprocess stage, with a master server on port 6100 and the | |||
| # worker started through worker.start_servable (worker_port=6200, master_port=6100). | |||
| # The servable config below raises RuntimeError for the first instance (global index 0) | |||
| # and yields the running index for the others; add_common returns that preprocess output. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| index = 0 | |||
| def preprocess(instances): | |||
| count = len(instances) | |||
| global index | |||
| for i in range(count): | |||
| ret = index | |||
| index += 1 | |||
| if ret == 0: | |||
| raise RuntimeError("runtime error") | |||
| yield ret | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| x3 = register.call_preprocess_pipeline(preprocess, x1) | |||
| y = register.call_servable(x1, x2) | |||
| return x3 | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| master.start_master_server("0.0.0.0", 6100) | |||
| worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: send three instances of 2x2 float32 inputs, each scaled by (i + 1). | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Only the first instance carries an error; the other two return the preprocess index. | |||
| assert "Preprocess Failed" in result[0]["error"] | |||
| assert result[1]["y"] == 1 | |||
| assert result[2]["y"] == 2 | |||
| @serving_test | |||
| def test_servable_worker_with_master_predict_check_failed(): | |||
| # Failure reported from the predict stage, worker started via start_servable_in_master. | |||
| # The first instance sends an x1 with 2 float32 elements instead of 4, so only that | |||
| # instance is expected to fail the model input size check. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| y = register.call_servable(x1, x2) | |||
| return y | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| worker.start_servable_in_master(base.servable_dir, base.servable_name) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances; instance 0 deliberately uses a 2x1 x1, the rest use 2x2. | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| if i == 0: | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| else: | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Instance 0 reports the size mismatch; instances 1 and 2 still produce an output "y". | |||
| assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"] | |||
| assert "y" in result[1] | |||
| assert "y" in result[2] | |||
| @serving_test | |||
| def test_servable_worker_alone_predict_check_failed(): | |||
| # Failure reported from the predict stage, with a master server on port 6100 and the | |||
| # worker started through worker.start_servable (worker_port=6200, master_port=6100). | |||
| # The first instance sends an x1 with 2 float32 elements instead of 4, so only that | |||
| # instance is expected to fail the model input size check. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| y = register.call_servable(x1, x2) | |||
| return y | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| master.start_master_server("0.0.0.0", 6100) | |||
| worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances; instance 0 deliberately uses a 2x1 x1, the rest use 2x2. | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| if i == 0: | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| else: | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Instance 0 reports the size mismatch; instances 1 and 2 still produce an output "y". | |||
| assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"] | |||
| assert "y" in result[1] | |||
| assert "y" in result[2] | |||
| @serving_test | |||
| def test_servable_worker_with_master_postprocess_runtime_error(): | |||
| # Failure raised inside the postprocess stage, worker started via start_servable_in_master. | |||
| # The servable config below raises RuntimeError on the first postprocess call (global | |||
| # index 0) and returns the running index on later calls; add_common returns that value. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| index = 0 | |||
| def postprocess(y): | |||
| global index | |||
| ret = index | |||
| index += 1 | |||
| if ret == 0: | |||
| raise RuntimeError("runtime error") | |||
| return ret | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| y = register.call_servable(x1, x2) | |||
| y = register.call_postprocess(postprocess, y) | |||
| return y | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| worker.start_servable_in_master(base.servable_dir, base.servable_name) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: send three instances of 2x2 float32 inputs, each scaled by (i + 1). | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Only the first instance carries an error; the other two return the postprocess index. | |||
| assert "Postprocess Failed" in result[0]["error"] | |||
| assert result[1]["y"] == 1 | |||
| assert result[2]["y"] == 2 | |||
| @serving_test | |||
| def test_servable_worker_alone_postprocess_runtime_error(): | |||
| # Failure raised inside the postprocess stage, with a master server on port 6100 and the | |||
| # worker started through worker.start_servable (worker_port=6200, master_port=6100). | |||
| # The servable config below raises RuntimeError on the first postprocess call (global | |||
| # index 0) and returns the running index on later calls; add_common returns that value. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += r""" | |||
| index = 0 | |||
| def postprocess(y): | |||
| global index | |||
| ret = index | |||
| index += 1 | |||
| if ret == 0: | |||
| raise RuntimeError("runtime error") | |||
| return ret | |||
| @register.register_method(output_names=["y"]) | |||
| def add_common(x1, x2): | |||
| y = register.call_servable(x1, x2) | |||
| y = register.call_postprocess(postprocess, y) | |||
| return y | |||
| """ | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| master.start_master_server("0.0.0.0", 6100) | |||
| worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: send three instances of 2x2 float32 inputs, each scaled by (i + 1). | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertions below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x1": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| # Only the first instance carries an error; the other two return the postprocess index. | |||
| assert "Postprocess Failed" in result[0]["error"] | |||
| assert result[1]["y"] == 1 | |||
| assert result[2]["y"] == 2 | |||
| @serving_test | |||
| def test_servable_worker_with_master_input_param_less(): | |||
| # Failure returned from Worker::RunAsync, worker started via start_servable_in_master. | |||
| # Every instance is sent with key "x3" in place of the required "x1", so the request as a | |||
| # whole is expected to be rejected (result is indexed as a dict, not per instance). | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += servable_config_method_add_common | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| worker.start_servable_in_master(base.servable_dir, base.servable_name) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances of 2x1 float32 inputs, submitted under the keys "x3" and "x2". | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertion below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x3": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| assert "Cannot find input x1 in instance input" in result["error"] | |||
| @serving_test | |||
| def test_servable_worker_alone_input_param_less(): | |||
| # Failure returned from Worker::RunAsync, with a master server on port 6100 and the | |||
| # worker started through worker.start_servable (worker_port=6200, master_port=6100). | |||
| # Every instance is sent with key "x3" in place of the required "x1", so the request as a | |||
| # whole is expected to be rejected (result is indexed as a dict, not per instance). | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += servable_config_method_add_common | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| master.start_master_server("0.0.0.0", 6100) | |||
| worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances of 2x1 float32 inputs, submitted under the keys "x3" and "x2". | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertion below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x3": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name, "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| assert "Cannot find input x1 in instance input" in result["error"] | |||
| @serving_test | |||
| def test_servable_worker_with_master_servable_not_available(): | |||
| # Request addressed to a servable name that was never started, worker started via | |||
| # start_servable_in_master. The client appends "error" to the real servable name, so the | |||
| # whole request is expected to fail with "servable is not available". | |||
| # NOTE(review): the original comment attributed this failure to Worker::RunAsync - confirm. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += servable_config_method_add_common | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| worker.start_servable_in_master(base.servable_dir, base.servable_name) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances of 2x1 float32 inputs, submitted under the keys "x3" and "x2". | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertion below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x3": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name + "error", "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| assert "servable is not available" in result["error"] | |||
| @serving_test | |||
| def test_servable_worker_alone_servable_not_available(): | |||
| # Request addressed to a servable name that was never started, with a master server on | |||
| # port 6100 and the worker started through worker.start_servable (worker_port=6200, | |||
| # master_port=6100). The client appends "error" to the real servable name, so the whole | |||
| # request is expected to fail with "servable is not available". | |||
| # NOTE(review): the original comment attributed this failure to Worker::RunAsync - confirm. | |||
| base = ServingTestBase() | |||
| servable_content = servable_config_import | |||
| servable_content += servable_config_declare_servable | |||
| servable_content += servable_config_method_add_common | |||
| base.init_servable_with_servable_config(1, servable_content) | |||
| master.start_master_server("0.0.0.0", 6100) | |||
| worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) | |||
| master.start_grpc_server("0.0.0.0", 5500) | |||
| # Client: three instances of 2x1 float32 inputs, submitted under the keys "x3" and "x2". | |||
| instance_count = 3 | |||
| instances = [] | |||
| # NOTE(review): y_data_list is filled but never read by the assertion below. | |||
| y_data_list = [] | |||
| for i in range(instance_count): | |||
| x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) | |||
| x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) | |||
| y_data_list.append(x1 + x2) | |||
| instances.append({"x3": x1, "x2": x2}) | |||
| client = create_client("localhost", 5500, base.servable_name + "error", "add_common") | |||
| result = client.infer(instances) | |||
| print(result) | |||
| assert "servable is not available" in result["error"] | |||
| @@ -14,6 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/context.h" | |||
| #include <any> | |||
| #include <map> | |||
| #include <type_traits> | |||
| #include "utils/log_adapter.h" | |||
| constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target"; | |||
| @@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode"; | |||
| constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; | |||
| namespace mindspore { | |||
| template <class T> | |||
| static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) { | |||
| auto iter = context->params.find(key); | |||
| if (iter == context->params.end()) { | |||
| return T(); | |||
| struct Context::Data { | |||
| std::map<std::string, std::any> params; | |||
| }; | |||
| Context::Context() : data(std::make_shared<Data>()) {} | |||
| template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>> | |||
| static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) { | |||
| static U empty_result; | |||
| if (context == nullptr || context->data == nullptr) { | |||
| return empty_result; | |||
| } | |||
| auto iter = context->data->params.find(key); | |||
| if (iter == context->data->params.end()) { | |||
| return empty_result; | |||
| } | |||
| const std::any &value = iter->second; | |||
| if (value.type() != typeid(T)) { | |||
| return T(); | |||
| if (value.type() != typeid(U)) { | |||
| return empty_result; | |||
| } | |||
| return std::any_cast<T>(value); | |||
| return std::any_cast<const U &>(value); | |||
| } | |||
| std::shared_ptr<Context> GlobalContext::GetGlobalContext() { | |||
| @@ -47,22 +60,31 @@ std::shared_ptr<Context> GlobalContext::GetGlobalContext() { | |||
| return g_context; | |||
| } | |||
| void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { | |||
| void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) { | |||
| auto global_context = GetGlobalContext(); | |||
| MS_EXCEPTION_IF_NULL(global_context); | |||
| global_context->params[kGlobalContextDeviceTarget] = device_target; | |||
| if (global_context->data == nullptr) { | |||
| global_context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(global_context->data); | |||
| } | |||
| global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target); | |||
| } | |||
| std::string GlobalContext::GetGlobalDeviceTarget() { | |||
| std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() { | |||
| auto global_context = GetGlobalContext(); | |||
| MS_EXCEPTION_IF_NULL(global_context); | |||
| return GetValue<std::string>(global_context, kGlobalContextDeviceTarget); | |||
| const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget); | |||
| return StringToChar(ref); | |||
| } | |||
| void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) { | |||
| auto global_context = GetGlobalContext(); | |||
| MS_EXCEPTION_IF_NULL(global_context); | |||
| global_context->params[kGlobalContextDeviceID] = device_id; | |||
| if (global_context->data == nullptr) { | |||
| global_context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(global_context->data); | |||
| } | |||
| global_context->data->params[kGlobalContextDeviceID] = device_id; | |||
| } | |||
| uint32_t GlobalContext::GetGlobalDeviceID() { | |||
| @@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() { | |||
| return GetValue<uint32_t>(global_context, kGlobalContextDeviceID); | |||
| } | |||
| void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) { | |||
| void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionInsertOpCfgPath] = cfg_path; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path); | |||
| } | |||
| std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) { | |||
| std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return GetValue<std::string>(context, kModelOptionInsertOpCfgPath); | |||
| const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath); | |||
| return StringToChar(ref); | |||
| } | |||
| void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) { | |||
| void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionInputFormat] = format; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionInputFormat] = CharToString(format); | |||
| } | |||
| std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) { | |||
| std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return GetValue<std::string>(context, kModelOptionInputFormat); | |||
| const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat); | |||
| return StringToChar(ref); | |||
| } | |||
| void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) { | |||
| void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionInputShape] = shape; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionInputShape] = CharToString(shape); | |||
| } | |||
| std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) { | |||
| std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return GetValue<std::string>(context, kModelOptionInputShape); | |||
| const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape); | |||
| return StringToChar(ref); | |||
| } | |||
| void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionOutputType] = output_type; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionOutputType] = output_type; | |||
| } | |||
| enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) { | |||
| @@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &contex | |||
| return GetValue<enum DataType>(context, kModelOptionOutputType); | |||
| } | |||
| void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) { | |||
| void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionPrecisionMode] = precision_mode; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode); | |||
| } | |||
| std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) { | |||
| std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return GetValue<std::string>(context, kModelOptionPrecisionMode); | |||
| const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode); | |||
| return StringToChar(ref); | |||
| } | |||
| void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context, | |||
| const std::string &op_select_impl_mode) { | |||
| const std::vector<char> &op_select_impl_mode) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode; | |||
| if (context->data == nullptr) { | |||
| context->data = std::make_shared<Data>(); | |||
| MS_EXCEPTION_IF_NULL(context->data); | |||
| } | |||
| context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode); | |||
| } | |||
| std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) { | |||
| std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return GetValue<std::string>(context, kModelOptionOpSelectImplMode); | |||
| const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode); | |||
| return StringToChar(ref); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -13,38 +13,38 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/ms/ms_graph_impl.h" | |||
| #include "cxx_api/graph/ascend/ascend_graph_impl.h" | |||
| #include <algorithm> | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "stub/graph_impl_stub.h" | |||
| namespace mindspore { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl); | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); | |||
| std::shared_ptr<GraphCell::GraphImpl> MsGraphImpl::graph_imp_stub_ = std::make_shared<GraphImplStubAdd>(); | |||
| std::shared_ptr<GraphCell::GraphImpl> AscendGraphImpl::graph_imp_stub_ = std::make_shared<GraphImplStubAdd>(); | |||
| MsGraphImpl::MsGraphImpl() {} | |||
| AscendGraphImpl::AscendGraphImpl() {} | |||
| MsGraphImpl::~MsGraphImpl() {} | |||
| AscendGraphImpl::~AscendGraphImpl() {} | |||
| std::vector<MSTensor> MsGraphImpl::GetInputs() { | |||
| std::vector<MSTensor> AscendGraphImpl::GetInputs() { | |||
| if (!graph_imp_stub_) { | |||
| return {}; | |||
| } | |||
| return graph_imp_stub_->GetInputs(); | |||
| } | |||
| std::vector<MSTensor> MsGraphImpl::GetOutputs() { | |||
| std::vector<MSTensor> AscendGraphImpl::GetOutputs() { | |||
| if (!graph_imp_stub_) { | |||
| return {}; | |||
| } | |||
| return graph_imp_stub_->GetOutputs(); | |||
| } | |||
| Status MsGraphImpl::Load() { return kSuccess; } | |||
| Status AscendGraphImpl::Load() { return kSuccess; } | |||
| Status MsGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | |||
| Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | |||
| if (!graph_imp_stub_) { | |||
| return kMCFailed; | |||
| } | |||
| @@ -13,25 +13,25 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <mutex> | |||
| #include "include/api/status.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_impl.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| namespace mindspore { | |||
| class MsGraphImpl : public GraphCell::GraphImpl { | |||
| class AscendGraphImpl : public GraphCell::GraphImpl { | |||
| public: | |||
| MsGraphImpl(); | |||
| ~MsGraphImpl() override; | |||
| AscendGraphImpl(); | |||
| ~AscendGraphImpl() override; | |||
| Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | |||
| Status Load() override; | |||
| @@ -43,4 +43,4 @@ class MsGraphImpl : public GraphCell::GraphImpl { | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| @@ -17,8 +17,16 @@ | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| // Maps a device-type name to the set of ModelType values listed for that device. | |||
| // Model::CheckModelSupport looks a (device, model type) pair up in this table; a device | |||
| // name that has no entry here is reported as unsupported by that function. | |||
| const std::map<std::string, std::set<ModelType>> kSupportedModelMap = { | |||
| {kDeviceTypeAscend310, {kOM, kMindIR}}, | |||
| {kDeviceTypeAscend910, {kMindIR}}, | |||
| {kDeviceTypeGPU, {kMindIR}}, | |||
| }; | |||
| } | |||
| Status Model::Build() { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->Build(); | |||
| @@ -60,8 +68,22 @@ Model::Model(const std::vector<Output> &network, const std::shared_ptr<Context> | |||
| Model::~Model() {} | |||
| bool Model::CheckModelSupport(const std::string &device_type, ModelType) { | |||
| return Factory<ModelImpl>::Instance().CheckModelSupport(device_type); | |||
| } | |||
| // Reports whether `model_type` may be used on the device named by `device_type`. | |||
| // Two conditions must hold: Factory<ModelImpl> reports support for the device name, and | |||
| // kSupportedModelMap lists `model_type` under that device name. | |||
| bool Model::CheckModelSupport(const std::vector<char> &device_type, ModelType model_type) { | |||
| const std::string device_name = CharToString(device_type); | |||
| const bool has_backend = Factory<ModelImpl>::Instance().CheckModelSupport(device_name); | |||
| if (has_backend) { | |||
| const auto device_entry = kSupportedModelMap.find(device_name); | |||
| if (device_entry != kSupportedModelMap.end()) { | |||
| // The mapped value is a std::set, so count() is either 0 or 1. | |||
| return device_entry->second.count(model_type) != 0; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "include/api/model.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| class ModelImpl { | |||
| @@ -63,7 +64,9 @@ class ModelImpl { | |||
| friend class Model; | |||
| void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | |||
| // Stores a private copy of the caller's context in model_context_. | |||
| // A null `model_context` is ignored: model_context_ is left as it was. | |||
| void SetContext(const std::shared_ptr<Context> &model_context) { | |||
| // Copy only behind the null check; dereferencing the pointer before the check would | |||
| // crash on a null context, which is exactly the case the check exists for. | |||
| if (model_context != nullptr) { | |||
| model_context_ = std::make_shared<Context>(*model_context); | |||
| } | |||
| } | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -66,6 +66,11 @@ Status MsModel::Build() { | |||
| MS_LOG(INFO) << "Start build model."; | |||
| MS_EXCEPTION_IF_NULL(graph_); | |||
| if (graph_cell_ != nullptr) { | |||
| MS_LOG(INFO) << "This model has been built, skip."; | |||
| return kSuccess; | |||
| } | |||
| auto func_graph = ModelImpl::GetFuncGraph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -16,7 +16,7 @@ | |||
| #include "include/api/serialization.h" | |||
| #include <fstream> | |||
| #include "cxx_api/graph/graph_data.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| static Buffer ReadFile(const std::string &file) { | |||
| @@ -77,13 +77,17 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy | |||
| MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; | |||
| } | |||
| Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { | |||
| Buffer data = ReadFile(file); | |||
| Graph Serialization::LoadModel(const std::vector<char> &file, ModelType model_type) { | |||
| std::string file_path = CharToString(file); | |||
| Buffer data = ReadFile(file_path); | |||
| if (data.Data() == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Read file " << file << " failed."; | |||
| MS_LOG(EXCEPTION) << "Read file " << file_path << " failed."; | |||
| } | |||
| if (model_type == kMindIR) { | |||
| auto anf_graph = std::make_shared<FuncGraph>(); | |||
| if (anf_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Load model failed."; | |||
| } | |||
| return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); | |||
| } else if (model_type == kOM) { | |||
| return Graph(std::make_shared<Graph::GraphData>(data, kOM)); | |||
| @@ -0,0 +1,207 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "include/api/status.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include <thread> | |||
| #endif | |||
| #include <map> | |||
| #include <sstream> | |||
| namespace mindspore { | |||
| // Payload that a Status holds through its shared_ptr member data_. | |||
| struct Status::Data { | |||
| // Result code; a fresh Data reports success. | |||
| enum StatusCode status_code { kSuccess }; | |||
| // Rendered message, returned by ToCString(). | |||
| std::string status_msg; | |||
| // Source line recorded by the location-aware constructor; -1 by default. | |||
| int line_of_code{-1}; | |||
| // Source file recorded by the location-aware constructor; empty by default. | |||
| std::string file_name; | |||
| // Description text supplied by the caller, without the rendered header. | |||
| std::string err_description; | |||
| }; | |||
| Status::Status() : data_(std::make_shared<Data>()) {} | |||
| // Builds a Status carrying `status_code` and the caller-supplied message `status_msg`. | |||
| Status::Status(enum StatusCode status_code, const std::vector<char> &status_msg) : data_(std::make_shared<Data>()) { | |||
| if (data_ != nullptr) { | |||
| data_->status_code = status_code; | |||
| data_->status_msg = CharToString(status_msg); | |||
| } | |||
| } | |||
| // Builds a Status for `code` that records where it was raised (`line_of_code`, `file_name`) | |||
| // and a description `extra`, and renders the message that ToCString() returns. | |||
| // A null `file_name` is allowed: the file name stays empty and no "File" line is rendered. | |||
| Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra) | |||
| : data_(std::make_shared<Data>()) { | |||
| if (data_ == nullptr) { | |||
| return; | |||
| } | |||
| const bool has_file = (file_name != nullptr); | |||
| data_->status_code = code; | |||
| data_->line_of_code = line_of_code; | |||
| if (has_file) { | |||
| data_->file_name = file_name; | |||
| } | |||
| data_->err_description = CharToString(extra); | |||
| std::ostringstream rendered; | |||
| #ifndef ENABLE_ANDROID | |||
| // Header with thread id, code text and description; compiled out when ENABLE_ANDROID is defined. | |||
| // Streaming an empty description writes nothing, so no emptiness check is needed. | |||
| rendered << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". " | |||
| << data_->err_description << "\n"; | |||
| #endif | |||
| rendered << "Line of code : " << line_of_code << "\n"; | |||
| if (has_file) { | |||
| rendered << "File : " << file_name << "\n"; | |||
| } | |||
| data_->status_msg = rendered.str(); | |||
| } | |||
| // Returns the stored status code; a Status without data reports kSuccess. | |||
| enum StatusCode Status::StatusCode() const { | |||
| return (data_ != nullptr) ? data_->status_code : kSuccess; | |||
| } | |||
| // Returns the rendered status message as a char vector; empty when no data is attached. | |||
| std::vector<char> Status::ToCString() const { | |||
| return (data_ != nullptr) ? StringToChar(data_->status_msg) : std::vector<char>(); | |||
| } | |||
| // Returns the recorded source line, or -1 when no data is attached. | |||
| int Status::GetLineOfCode() const { | |||
| return (data_ != nullptr) ? data_->line_of_code : -1; | |||
| } | |||
| std::vector<char> Status::GetErrDescriptionChar() const { | |||
| if (data_ == nullptr) { | |||
| return std::vector<char>(); | |||
| } | |||
| return StringToChar(data_->status_msg); | |||
| } | |||
| // Translates a status code into its fixed description text, returned as a char vector. | |||
| // Codes that are not in the table yield "Unknown error". | |||
| std::vector<char> Status::CodeAsCString(enum StatusCode c) { | |||
| // Built once on first use and never modified afterwards. | |||
| static const std::map<enum StatusCode, std::string> kCodeText = { | |||
| {kSuccess, "No error occurs."}, | |||
| // Core | |||
| {kCoreFailed, "Common error code."}, | |||
| // MD | |||
| {kMDOutOfMemory, "Out of memory"}, | |||
| {kMDShapeMisMatch, "Shape is incorrect"}, | |||
| {kMDInterrupted, "Interrupted system call"}, | |||
| {kMDNoSpace, "No space left on device"}, | |||
| {kMDPyFuncException, "Exception thrown from PyFunc"}, | |||
| {kMDDuplicateKey, "Duplicate key"}, | |||
| {kMDPythonInterpreterFailure, ""}, | |||
| {kMDTDTPushFailure, "Unexpected error"}, | |||
| {kMDFileNotExist, "Unexpected error"}, | |||
| {kMDProfilingError, "Error encountered while profiling"}, | |||
| {kMDBoundingBoxOutOfBounds, "Unexpected error"}, | |||
| {kMDBoundingBoxInvalidShape, "Unexpected error"}, | |||
| {kMDSyntaxError, "Syntax error"}, | |||
| {kMDTimeOut, "Unexpected error"}, | |||
| {kMDBuddySpaceFull, "BuddySpace full"}, | |||
| {kMDNetWorkError, "Network error"}, | |||
| {kMDNotImplementedYet, "Unexpected error"}, | |||
| {kMDUnexpectedError, "Unexpected error"}, | |||
| // ME | |||
| {kMEFailed, "Common error code."}, | |||
| {kMEInvalidInput, "Invalid input."}, | |||
| // MC | |||
| {kMCFailed, "Common error code."}, | |||
| {kMCDeviceError, "Device error."}, | |||
| {kMCInvalidInput, "Invalid input."}, | |||
| {kMCInvalidArgs, "Invalid arguments."}, | |||
| // Lite | |||
| {kLiteError, "Common error code."}, | |||
| {kLiteNullptr, "NULL pointer returned."}, | |||
| {kLiteParamInvalid, "Invalid parameter."}, | |||
| {kLiteNoChange, "No change."}, | |||
| {kLiteSuccessExit, "No error but exit."}, | |||
| {kLiteMemoryFailed, "Fail to create memory."}, | |||
| {kLiteNotSupport, "Fail to support."}, | |||
| {kLiteThreadPoolError, "Thread pool error."}, | |||
| {kLiteOutOfTensorRange, "Failed to check range."}, | |||
| {kLiteInputTensorError, "Failed to check input tensor."}, | |||
| {kLiteReentrantError, "Exist executor running."}, | |||
| {kLiteGraphFileError, "Failed to verify graph file."}, | |||
| {kLiteNotFindOp, "Failed to find operator."}, | |||
| {kLiteInvalidOpName, "Invalid operator name."}, | |||
| {kLiteInvalidOpAttr, "Invalid operator attr."}, | |||
| {kLiteOpExecuteFailure, "Failed to execution operator."}, | |||
| {kLiteFormatError, "Failed to checking tensor format."}, | |||
| {kLiteInferError, "Failed to infer shape."}, | |||
| {kLiteInferInvalid, "Invalid infer shape before runtime."}, | |||
| {kLiteInputParamInvalid, "Invalid input param by user."}}; | |||
| const auto entry = kCodeText.find(c); | |||
| if (entry != kCodeText.end()) { | |||
| return StringToChar(entry->second); | |||
| } | |||
| return StringToChar("Unknown error"); | |||
| } | |||
| // Streams the status message (the text returned by ToString()) into `os`. | |||
| std::ostream &operator<<(std::ostream &os, const Status &s) { | |||
| return os << s.ToString(); | |||
| } | |||
| std::vector<char> Status::SetErrDescription(const std::vector<char> &err_description) { | |||
| if (data_ == nullptr) { | |||
| return std::vector<char>(); | |||
| } | |||
| data_->err_description = CharToString(err_description); | |||
| std::ostringstream ss; | |||
| #ifndef ENABLE_ANDROID | |||
| ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". "; | |||
| if (!data_->err_description.empty()) { | |||
| ss << data_->err_description; | |||
| } | |||
| ss << "\n"; | |||
| #endif | |||
| if (data_->line_of_code > 0 && !data_->file_name.empty()) { | |||
| ss << "Line of code : " << data_->line_of_code << "\n"; | |||
| ss << "File : " << data_->file_name << "\n"; | |||
| } | |||
| data_->status_msg = ss.str(); | |||
| return StringToChar(data_->status_msg); | |||
| } | |||
| // Two statuses are equal when their codes match. Statuses without data are equal only | |||
| // to other statuses without data. | |||
| bool Status::operator==(const Status &other) const { | |||
| const bool lhs_empty = (data_ == nullptr); | |||
| const bool rhs_empty = (other.data_ == nullptr); | |||
| if (lhs_empty || rhs_empty) { | |||
| return lhs_empty && rhs_empty; | |||
| } | |||
| return data_->status_code == other.data_->status_code; | |||
| } | |||
| // Compares the stored code (kSuccess when no data is attached) with `other_code`. | |||
| bool Status::operator==(enum StatusCode other_code) const { | |||
| return StatusCode() == other_code; | |||
| } | |||
| // Negation of operator==(const Status &). | |||
| bool Status::operator!=(const Status &other) const { | |||
| return !(*this == other); | |||
| } | |||
| // Negation of operator==(enum StatusCode). | |||
| bool Status::operator!=(enum StatusCode other_code) const { | |||
| return !(*this == other_code); | |||
| } | |||
| // True exactly when the status code is kSuccess. | |||
| Status::operator bool() const { | |||
| return StatusCode() == kSuccess; | |||
| } | |||
| // The status code converted to int. | |||
| Status::operator int() const { | |||
| const enum StatusCode code = StatusCode(); | |||
| return static_cast<int>(code); | |||
| } | |||
| // Returns a Status whose code is kSuccess. | |||
| Status Status::OK() { | |||
| return Status(StatusCode::kSuccess); | |||
| } | |||
| // True when the status code is kSuccess. | |||
| bool Status::IsOk() const { | |||
| return StatusCode() == StatusCode::kSuccess; | |||
| } | |||
| // True when the status code is anything other than kSuccess. | |||
| bool Status::IsError() const { | |||
| return StatusCode() != StatusCode::kSuccess; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -133,10 +133,11 @@ class TensorReferenceImpl : public MSTensor::Impl { | |||
| std::vector<int64_t> shape_; | |||
| }; | |||
| MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| MSTensor MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept { | |||
| std::string name_str = CharToString(name); | |||
| try { | |||
| std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len); | |||
| std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len); | |||
| return MSTensor(impl); | |||
| } catch (const std::bad_alloc &) { | |||
| MS_LOG(ERROR) << "Malloc memory failed."; | |||
| @@ -147,10 +148,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con | |||
| } | |||
| } | |||
| MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| MSTensor MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept { | |||
| std::string name_str = CharToString(name); | |||
| try { | |||
| std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name, type, shape, data, data_len); | |||
| std::shared_ptr<Impl> impl = std::make_shared<TensorReferenceImpl>(name_str, type, shape, data, data_len); | |||
| return MSTensor(impl); | |||
| } catch (const std::bad_alloc &) { | |||
| MS_LOG(ERROR) << "Malloc memory failed."; | |||
| @@ -164,9 +166,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, | |||
| // Default constructor: the tensor is backed by a default-constructed TensorDefaultImpl. | |||
| MSTensor::MSTensor() : impl_(std::make_shared<TensorDefaultImpl>()) {} | |||
| // Constructs a tensor with no implementation attached (impl_ is null). | |||
| MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} | |||
| // Wraps an existing implementation; `impl` is checked with MS_EXCEPTION_IF_NULL. | |||
| MSTensor::MSTensor(const std::shared_ptr<Impl> &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); } | |||
| MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, | |||
| size_t data_len) | |||
| : impl_(std::make_shared<TensorDefaultImpl>(name, type, shape, data, data_len)) {} | |||
| // std::string-free constructor: converts `name` back to std::string with CharToString and | |||
| // forwards all arguments to a newly created TensorDefaultImpl. | |||
| MSTensor::MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) | |||
| : impl_(std::make_shared<TensorDefaultImpl>(CharToString(name), type, shape, data, data_len)) {} | |||
| MSTensor::~MSTensor() = default; | |||
| // True when this tensor has no implementation attached. | |||
| bool MSTensor::operator==(std::nullptr_t) const { | |||
| return !impl_; | |||
| } | |||
| @@ -178,9 +180,9 @@ MSTensor MSTensor::Clone() const { | |||
| return ret; | |||
| } | |||
| const std::string &MSTensor::Name() const { | |||
| std::vector<char> MSTensor::CharName() const { | |||
| MS_EXCEPTION_IF_NULL(impl_); | |||
| return impl_->Name(); | |||
| return StringToChar(impl_->Name()); | |||
| } | |||
| enum DataType MSTensor::DataType() const { | |||
| @@ -16,49 +16,120 @@ | |||
| #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| #define MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| #include <map> | |||
| #include <any> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "include/api/types.h" | |||
| #include "include/api/dual_abi_helper.h" | |||
| namespace mindspore { | |||
| // Device-type name strings. They serve as the keys of kSupportedModelMap, which | |||
| // Model::CheckModelSupport consults. | |||
| constexpr auto kDeviceTypeAscend310 = "Ascend310"; | |||
| constexpr auto kDeviceTypeAscend910 = "Ascend910"; | |||
| constexpr auto kDeviceTypeGPU = "GPU"; | |||
| struct MS_API Context { | |||
| public: | |||
| Context(); | |||
| virtual ~Context() = default; | |||
| std::map<std::string, std::any> params; | |||
| struct Data; | |||
| std::shared_ptr<Data> data; | |||
| }; | |||
| struct MS_API GlobalContext : public Context { | |||
| public: | |||
| static std::shared_ptr<Context> GetGlobalContext(); | |||
| static void SetGlobalDeviceTarget(const std::string &device_target); | |||
| static std::string GetGlobalDeviceTarget(); | |||
| static inline void SetGlobalDeviceTarget(const std::string &device_target); | |||
| static inline std::string GetGlobalDeviceTarget(); | |||
| static void SetGlobalDeviceID(const uint32_t &device_id); | |||
| static uint32_t GetGlobalDeviceID(); | |||
| private: | |||
| // api without std::string | |||
| static void SetGlobalDeviceTarget(const std::vector<char> &device_target); | |||
| static std::vector<char> GetGlobalDeviceTargetChar(); | |||
| }; | |||
| struct MS_API ModelContext : public Context { | |||
| static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path); | |||
| static std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context); | |||
| public: | |||
| static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path); | |||
| static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context); | |||
| static void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format); | |||
| static std::string GetInputFormat(const std::shared_ptr<Context> &context); | |||
| static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format); | |||
| static inline std::string GetInputFormat(const std::shared_ptr<Context> &context); | |||
| static void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape); | |||
| static std::string GetInputShape(const std::shared_ptr<Context> &context); | |||
| static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape); | |||
| static inline std::string GetInputShape(const std::shared_ptr<Context> &context); | |||
| static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type); | |||
| static enum DataType GetOutputType(const std::shared_ptr<Context> &context); | |||
| static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode); | |||
| static std::string GetPrecisionMode(const std::shared_ptr<Context> &context); | |||
| static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode); | |||
| static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context); | |||
| static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context, | |||
| const std::string &op_select_impl_mode); | |||
| static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context); | |||
| private: | |||
| // api without std::string | |||
| static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path); | |||
| static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context); | |||
| static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format); | |||
| static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context); | |||
| static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape); | |||
| static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context); | |||
| static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode); | |||
| static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context); | |||
| static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, const std::string &op_select_impl_mode); | |||
| static std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context); | |||
| static void SetOpSelectImplMode(const std::shared_ptr<Context> &context, | |||
| const std::vector<char> &op_select_impl_mode); | |||
| static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context); | |||
| }; | |||
| // std::string front end: converts the name to std::vector<char> and forwards to the | |||
| // private overload that takes no std::string. | |||
| void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { | |||
| const std::vector<char> target_chars = StringToChar(device_target); | |||
| SetGlobalDeviceTarget(target_chars); | |||
| } | |||
| // std::string front end: converts the result of GetGlobalDeviceTargetChar() to std::string. | |||
| std::string GlobalContext::GetGlobalDeviceTarget() { | |||
| const std::vector<char> target_chars = GetGlobalDeviceTargetChar(); | |||
| return CharToString(target_chars); | |||
| } | |||
| // std::string front ends for ModelContext. Each setter converts its string argument to | |||
| // std::vector<char> and forwards to the private overload of the same name; each getter | |||
| // converts the result of the matching private *Char getter back to std::string. | |||
| void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) { | |||
| const std::vector<char> cfg_path_chars = StringToChar(cfg_path); | |||
| SetInsertOpConfigPath(context, cfg_path_chars); | |||
| } | |||
| std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) { | |||
| const std::vector<char> cfg_path_chars = GetInsertOpConfigPathChar(context); | |||
| return CharToString(cfg_path_chars); | |||
| } | |||
| void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) { | |||
| const std::vector<char> format_chars = StringToChar(format); | |||
| SetInputFormat(context, format_chars); | |||
| } | |||
| std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) { | |||
| const std::vector<char> format_chars = GetInputFormatChar(context); | |||
| return CharToString(format_chars); | |||
| } | |||
| void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) { | |||
| const std::vector<char> shape_chars = StringToChar(shape); | |||
| SetInputShape(context, shape_chars); | |||
| } | |||
| std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) { | |||
| const std::vector<char> shape_chars = GetInputShapeChar(context); | |||
| return CharToString(shape_chars); | |||
| } | |||
| void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) { | |||
| const std::vector<char> mode_chars = StringToChar(precision_mode); | |||
| SetPrecisionMode(context, mode_chars); | |||
| } | |||
| std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) { | |||
| const std::vector<char> mode_chars = GetPrecisionModeChar(context); | |||
| return CharToString(mode_chars); | |||
| } | |||
| void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context, | |||
| const std::string &op_select_impl_mode) { | |||
| const std::vector<char> mode_chars = StringToChar(op_select_impl_mode); | |||
| SetOpSelectImplMode(context, mode_chars); | |||
| } | |||
| std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) { | |||
| const std::vector<char> mode_chars = GetOpSelectImplModeChar(context); | |||
| return CharToString(mode_chars); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_CONTEXT_H | |||
| @@ -1,7 +1,5 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,7 +13,6 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_ | |||
| #define MINDSPORE_INCLUDE_API_DATA_TYPE_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ | |||
| #define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| namespace mindspore { | |||
| // Copies every byte of `s`, embedded NULs included, into a std::vector<char>. | |||
| // No terminating NUL is appended. | |||
| inline std::vector<char> StringToChar(const std::string &s) { | |||
| std::vector<char> chars; | |||
| chars.assign(s.begin(), s.end()); | |||
| return chars; | |||
| } | |||
| // Builds a std::string from every element of `c`; the vector is not expected to carry | |||
| // a terminating NUL, and any NUL it does contain becomes part of the string. | |||
| inline std::string CharToString(const std::vector<char> &c) { | |||
| std::string text; | |||
| text.assign(c.begin(), c.end()); | |||
| return text; | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ | |||
| @@ -17,7 +17,6 @@ | |||
| #define MINDSPORE_INCLUDE_API_GRAPH_H | |||
| #include <cstddef> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| @@ -25,6 +25,7 @@ | |||
| #include "include/api/types.h" | |||
| #include "include/api/graph.h" | |||
| #include "include/api/cell.h" | |||
| #include "include/api/dual_abi_helper.h" | |||
| namespace mindspore { | |||
| class ModelImpl; | |||
| @@ -46,10 +47,16 @@ class MS_API Model { | |||
| std::vector<MSTensor> GetInputs(); | |||
| std::vector<MSTensor> GetOutputs(); | |||
| static bool CheckModelSupport(const std::string &device_type, ModelType model_type); | |||
| static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type); | |||
| private: | |||
| // api without std::string | |||
| static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type); | |||
| std::shared_ptr<ModelImpl> impl_; | |||
| }; | |||
| // std::string front end: converts the device name to std::vector<char> and forwards to | |||
| // the private overload that takes no std::string. | |||
| bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { | |||
| const std::vector<char> device_chars = StringToChar(device_type); | |||
| return CheckModelSupport(device_chars, model_type); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_MODEL_H | |||
| @@ -24,16 +24,24 @@ | |||
| #include "include/api/types.h" | |||
| #include "include/api/model.h" | |||
| #include "include/api/graph.h" | |||
| #include "include/api/dual_abi_helper.h" | |||
| namespace mindspore { | |||
| class MS_API Serialization { | |||
| public: | |||
| static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type); | |||
| static Graph LoadModel(const std::string &file, ModelType model_type); | |||
| inline static Graph LoadModel(const std::string &file, ModelType model_type); | |||
| static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters); | |||
| static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model); | |||
| static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); | |||
| static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); | |||
| private: | |||
| static Graph LoadModel(const std::vector<char> &file, ModelType model_type); | |||
| }; | |||
| // std::string front end: converts the file path to std::vector<char> and forwards to the | |||
| // private overload that takes no std::string. | |||
| Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { | |||
| const std::vector<char> file_chars = StringToChar(file); | |||
| return LoadModel(file_chars, model_type); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H | |||
| @@ -16,9 +16,13 @@ | |||
| #ifndef MINDSPORE_INCLUDE_API_STATUS_H | |||
| #define MINDSPORE_INCLUDE_API_STATUS_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <ostream> | |||
| #include <climits> | |||
| #include "include/api/dual_abi_helper.h" | |||
| #include "include/api/types.h" | |||
| namespace mindspore { | |||
| enum CompCode : uint32_t { | |||
| @@ -100,39 +104,61 @@ enum StatusCode : uint32_t { | |||
| kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */ | |||
| }; | |||
| class Status { | |||
| class MS_API Status { | |||
| public: | |||
| Status() : status_code_(kSuccess) {} | |||
| Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) | |||
| : status_code_(status_code), status_msg_(status_msg) {} | |||
| Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); | |||
| Status(); | |||
| inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit) | |||
| inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); | |||
| ~Status() = default; | |||
| enum StatusCode StatusCode() const { return status_code_; } | |||
| const std::string &ToString() const { return status_msg_; } | |||
| enum StatusCode StatusCode() const; | |||
| inline std::string ToString() const; | |||
| int GetLineOfCode() const; | |||
| inline std::string GetErrDescription() const; | |||
| inline std::string SetErrDescription(const std::string &err_description); | |||
| friend std::ostream &operator<<(std::ostream &os, const Status &s); | |||
| bool operator==(const Status &other) const { return status_code_ == other.status_code_; } | |||
| bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } | |||
| bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } | |||
| bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } | |||
| bool operator==(const Status &other) const; | |||
| bool operator==(enum StatusCode other_code) const; | |||
| bool operator!=(const Status &other) const; | |||
| bool operator!=(enum StatusCode other_code) const; | |||
| explicit operator bool() const { return (status_code_ == kSuccess); } | |||
| explicit operator int() const { return static_cast<int>(status_code_); } | |||
| explicit operator bool() const; | |||
| explicit operator int() const; | |||
| static Status OK() { return Status(StatusCode::kSuccess); } | |||
| static Status OK(); | |||
| bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); } | |||
| bool IsOk() const; | |||
| bool IsError() const { return !IsOk(); } | |||
| bool IsError() const; | |||
| static std::string CodeAsString(enum StatusCode c); | |||
| static inline std::string CodeAsString(enum StatusCode c); | |||
| private: | |||
| enum StatusCode status_code_; | |||
| std::string status_msg_; | |||
| // api without std::string | |||
| explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg); | |||
| Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra); | |||
| std::vector<char> ToCString() const; | |||
| std::vector<char> GetErrDescriptionChar() const; | |||
| std::vector<char> SetErrDescription(const std::vector<char> &err_description); | |||
| static std::vector<char> CodeAsCString(enum StatusCode c); | |||
| struct Data; | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| // std::string front end: converts the message to std::vector<char> and delegates to the | |||
| // private constructor that takes no std::string. | |||
| Status::Status(enum StatusCode status_code, const std::string &status_msg) | |||
| : Status(status_code, StringToChar(status_msg)) {} | |||
| // std::string front end of the location-aware constructor: converts `extra` to | |||
| // std::vector<char> and delegates to the private constructor that takes no std::string. | |||
| Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) | |||
| : Status(code, line_of_code, file_name, StringToChar(extra)) {} | |||
| // std::string front ends for Status. Each converts between std::string and | |||
| // std::vector<char> around the private member that takes no std::string. | |||
| std::string Status::ToString() const { | |||
| const std::vector<char> msg_chars = ToCString(); | |||
| return CharToString(msg_chars); | |||
| } | |||
| std::string Status::GetErrDescription() const { | |||
| const std::vector<char> desc_chars = GetErrDescriptionChar(); | |||
| return CharToString(desc_chars); | |||
| } | |||
| std::string Status::SetErrDescription(const std::string &err_description) { | |||
| const std::vector<char> desc_chars = StringToChar(err_description); | |||
| const std::vector<char> msg_chars = SetErrDescription(desc_chars); | |||
| return CharToString(msg_chars); | |||
| } | |||
| std::string Status::CodeAsString(enum StatusCode c) { | |||
| const std::vector<char> text_chars = CodeAsCString(c); | |||
| return CharToString(text_chars); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_STATUS_H | |||
| @@ -21,22 +21,10 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "include/api/data_type.h" | |||
| #include "include/api/dual_abi_helper.h" | |||
| // refer to https://gcc.gnu.org/wiki/Visibility | |||
| #if defined _WIN32 || defined __CYGWIN__ | |||
| #ifdef BUILDING_DLL | |||
| #ifdef __GNUC__ | |||
| #define MS_API __attribute__((dllexport)) | |||
| #else | |||
| #define MS_API __declspec(dllexport) // Note: actually gcc seems to also supports this syntax. | |||
| #endif | |||
| #else | |||
| #ifdef __GNUC__ | |||
| #define MS_API __attribute__((dllimport)) | |||
| #else | |||
| #define MS_API __declspec(dllimport) // Note: actually gcc seems to also supports this syntax. | |||
| #endif | |||
| #endif | |||
| #ifdef _WIN32 | |||
| #define MS_API __declspec(dllexport) | |||
| #else | |||
| #define MS_API __attribute__((visibility("default"))) | |||
| #endif | |||
| @@ -55,18 +43,18 @@ class MS_API MSTensor { | |||
| public: | |||
| class Impl; | |||
| static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| MSTensor(); | |||
| explicit MSTensor(const std::shared_ptr<Impl> &impl); | |||
| MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, | |||
| size_t data_len); | |||
| inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data, | |||
| size_t data_len); | |||
| ~MSTensor(); | |||
| const std::string &Name() const; | |||
| inline std::string Name() const; | |||
| enum DataType DataType() const; | |||
| const std::vector<int64_t> &Shape() const; | |||
| int64_t ElementNum() const; | |||
| @@ -81,6 +69,15 @@ class MS_API MSTensor { | |||
| bool operator==(std::nullptr_t) const; | |||
| private: | |||
| // api without std::string | |||
| static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept; | |||
| MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, | |||
| size_t data_len); | |||
| std::vector<char> CharName() const; | |||
| friend class ModelImpl; | |||
| explicit MSTensor(std::nullptr_t); | |||
| std::shared_ptr<Impl> impl_; | |||
| @@ -123,5 +120,21 @@ class MS_API Buffer { | |||
| class Impl; | |||
| std::shared_ptr<Impl> impl_; | |||
| }; | |||
| MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept { /* std::string-name overload: converts name with StringToChar() and forwards type, shape, data and data_len unchanged to the CreateTensor overload declared under "api without std::string". */ | |||
| return CreateTensor(StringToChar(name), type, shape, data, data_len); | |||
| } | |||
| MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, | |||
| const void *data, size_t data_len) noexcept { /* std::string-name overload: converts name with StringToChar() and forwards type, shape, data and data_len unchanged to the CreateRefTensor overload declared under "api without std::string". */ | |||
| return CreateRefTensor(StringToChar(name), type, shape, data, data_len); | |||
| } | |||
| MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, | |||
| size_t data_len) /* std::string-name constructor: delegates to the constructor that takes the name as converted by StringToChar(); the remaining arguments are passed through unchanged. */ | |||
| : MSTensor(StringToChar(name), type, shape, data, data_len) {} | |||
| std::string MSTensor::Name() const { return CharToString(CharName()); } /* Returns the tensor name by value: the std::vector<char> from CharName() converted with CharToString(). */ | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_API_TYPES_H | |||
| @@ -21,6 +21,7 @@ | |||
| #include <atomic> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -1 +1 @@ | |||
| Subproject commit 52fac12367131ec57e87ba757e42fc25479f433a | |||
| Subproject commit dd22b5ea7106baf494704be04e2dbaad6887f0ab | |||