diff --git a/example/matmul_distributed/export_model/distributed_inference.py b/example/matmul_distributed/export_model/distributed_inference.py index d9d94a4..1095538 100644 --- a/example/matmul_distributed/export_model/distributed_inference.py +++ b/example/matmul_distributed/export_model/distributed_inference.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/example/matmul_distributed/export_model/net.py b/example/matmul_distributed/export_model/net.py index feeb499..53d88b6 100644 --- a/example/matmul_distributed/export_model/net.py +++ b/example/matmul_distributed/export_model/net.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/common/grpc_client.h b/mindspore_serving/ccsrc/common/grpc_client.h index c784713..0991d0b 100644 --- a/mindspore_serving/ccsrc/common/grpc_client.h +++ b/mindspore_serving/ccsrc/common/grpc_client.h @@ -39,7 +39,7 @@ namespace serving { using PredictOnFinish = std::function; -using DispatchCallback = std::function; +using AsyncPredictCallback = std::function; template class MSServiceClient { @@ -80,7 +80,7 @@ class MSServiceClient { } } - void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) { + void PredictAsync(const Request &request, Reply *reply, MSStub *stub, AsyncPredictCallback callback) { AsyncClientCall *call = new AsyncClientCall; call->reply = reply; call->callback = std::move(callback); @@ -95,7 +95,7 @@ class MSServiceClient { grpc::ClientContext context; grpc::Status status; Reply *reply; - DispatchCallback callback; + AsyncPredictCallback callback; std::shared_ptr> response_reader; }; diff --git a/mindspore_serving/ccsrc/common/proto_tensor.cc b/mindspore_serving/ccsrc/common/proto_tensor.cc index 98b779c..dfecdb8 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.cc +++ b/mindspore_serving/ccsrc/common/proto_tensor.cc @@ -256,8 +256,17 @@ Status GrpcTensorHelper::CreateInstanceFromRequest(const proto::PredictRequest & return SUCCESS; } -Status GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request, - const vector &instances, proto::PredictReply *reply) { +void GrpcTensorHelper::CreateReplyFromInstances(const proto::PredictRequest &request, + const vector &instances, proto::PredictReply *reply) { + auto status = CreateReplyFromInstancesInner(request, instances, reply); + if (status != SUCCESS) { + CreateReplyFromErrorMsg(status, reply); + } +} + +Status GrpcTensorHelper::CreateReplyFromInstancesInner(const proto::PredictRequest &request, + const std::vector &instances, + proto::PredictReply *reply) { MSI_EXCEPTION_IF_NULL(reply); *reply->mutable_servable_spec() = request.servable_spec(); if (instances.empty()) { @@ -422,6 +431,23 @@ Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { return SUCCESS; } +void GrpcTensorHelper::CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply) { + MSI_EXCEPTION_IF_NULL(reply); + if (error_msg == SUCCESS) { + return; + } + reply->clear_error_msg(); + reply->clear_instances(); + auto proto_error_msg = reply->add_error_msg(); + proto_error_msg->set_error_code(FAILED); + std::string error_msg_str = error_msg.StatusMessage(); + if (error_msg_str.empty()) { + proto_error_msg->set_error_msg("Predict failed"); + } else { + proto_error_msg->set_error_msg(error_msg_str); + } +} + serving::LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type) { const std::map type_name_map{ {proto::MS_UNKNOWN, "proto::MS_UNKNOWN"}, {proto::MS_BOOL, "proto::kMSI_Bool"}, diff --git a/mindspore_serving/ccsrc/common/proto_tensor.h b/mindspore_serving/ccsrc/common/proto_tensor.h index 518922b..a40c18c 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.h +++ b/mindspore_serving/ccsrc/common/proto_tensor.h @@ -67,8 +67,9 @@ class MS_API GrpcTensorHelper { static void GetWorkerSpec(const proto::RemoveWorkerRequest &request, WorkerSpec *worker_spec); static Status CreateInstanceFromRequest(const proto::PredictRequest &request, RequestSpec *request_spec, std::vector *results); - static Status CreateReplyFromInstances(const proto::PredictRequest &request, - const std::vector &instances, proto::PredictReply *reply); + static void CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector &instances, + proto::PredictReply *reply); + static void CreateReplyFromErrorMsg(const Status &error_msg, proto::PredictReply *reply); static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs); static void CopyFromWorkerAgentSpec(const std::vector &worker_specs, proto::AgentRegisterRequest *request); @@ -78,6 +79,9 @@ class MS_API GrpcTensorHelper { const std::vector &input_names, std::vector *results); static Status CheckRequestTensor(const proto::Tensor &tensor); + + static Status CreateReplyFromInstancesInner(const proto::PredictRequest &request, + const std::vector &instances, proto::PredictReply *reply); }; extern MS_API LogStream &operator<<(serving::LogStream &stream, proto::DataType data_type); diff --git a/mindspore_serving/ccsrc/master/dispacther.cc b/mindspore_serving/ccsrc/master/dispacther.cc index bd53371..92bf6e6 100644 --- a/mindspore_serving/ccsrc/master/dispacther.cc +++ b/mindspore_serving/ccsrc/master/dispacther.cc @@ -55,25 +55,47 @@ DispatcherWorkerContext Dispatcher::GetWorkSession(const RequestSpec &request_sp return context; } -Status Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) { +void Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) { MSI_EXCEPTION_IF_NULL(reply); - auto promise = std::make_shared, Status>>(std::make_pair(std::promise(), FAILED)); - auto future = promise->first.get_future(); - DispatchCallback callback = [promise](Status status) { - promise->second = status; - promise->first.set_value(); - }; - auto status = DispatchAsync(request, reply, callback); + auto promise = std::make_shared>(); + auto future = promise->get_future(); + PredictOnFinish on_finish = [promise]() { promise->set_value(); }; + DispatchAsync(request, reply, on_finish); + future.get(); // wait callback finish +} + +void Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, + PredictOnFinish on_finish) { + MSI_EXCEPTION_IF_NULL(reply); + Status status; + (*reply->mutable_servable_spec()) = request.servable_spec(); + try { + MSI_TIME_STAMP_START(Predict) + status = DispatchAsyncInner(request, reply, on_finish); + MSI_TIME_STAMP_END(Predict) + } catch (const std::bad_alloc &ex) { + MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; + std::cout << "Serving Error: malloc memory failed" << std::endl; + } catch (const std::runtime_error &ex) { + MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); + std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl; + } catch (const std::exception &ex) { + MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); + std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl; + } catch (...) { + MSI_LOG(ERROR) << "Serving Error: exception occurred"; + std::cout << "Serving Error: exception occurred"; + } + MSI_LOG(INFO) << "Finish call service Eval"; + if (status != SUCCESS) { - MSI_LOG_ERROR << "DispatchAsync failed"; - return status; + GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); + on_finish(); } - future.get(); // wait callback finish - return promise->second; } -Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) { +Status Dispatcher::DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, + PredictOnFinish on_finish) { MSI_EXCEPTION_IF_NULL(reply); std::shared_lock lock(servable_shared_lock_); RequestSpec request_spec; @@ -88,7 +110,7 @@ Status Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::Pr if (!find_method) { return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available"; } - return worker.notify_worker_->DispatchAsync(request, reply, std::move(callback)); + return worker.notify_worker_->DispatchAsync(request, reply, std::move(on_finish)); } Status Dispatcher::RegisterServableCommon(const std::vector &worker_specs, CreateNotifyWorkerFunc func) { @@ -216,7 +238,7 @@ Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto std::vector worker_specs; GrpcTensorHelper::GetWorkerSpec(request, &worker_specs); auto create_notify_worker = [](const WorkerSpec &worker_spec) { - std::shared_ptr notify_worker = std::make_shared(worker_spec.worker_address); + std::shared_ptr notify_worker = std::make_shared(worker_spec.worker_address); return notify_worker; }; return RegisterServableCommon(worker_specs, create_notify_worker); @@ -232,7 +254,7 @@ Status Dispatcher::AddServable(const proto::AddWorkerRequest &request, proto::Ad GrpcTensorHelper::GetWorkerSpec(request, &worker_spec); auto create_notify_worker = [](const WorkerSpec &worker_spec) { - std::shared_ptr notify_worker = std::make_shared(worker_spec.worker_address); + std::shared_ptr notify_worker = std::make_shared(worker_spec.worker_address); return notify_worker; }; return AddServableCommon(worker_spec, create_notify_worker); diff --git a/mindspore_serving/ccsrc/master/dispacther.h b/mindspore_serving/ccsrc/master/dispacther.h index 4e5e406..73f5235 100644 --- a/mindspore_serving/ccsrc/master/dispacther.h +++ b/mindspore_serving/ccsrc/master/dispacther.h @@ -40,8 +40,8 @@ class MS_API Dispatcher { public: Dispatcher(); ~Dispatcher(); - Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply); - Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback); + void Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply); + void DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish); Status RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply *reply); Status UnregisterServable(const proto::ExitRequest &request, proto::ExitReply *reply); @@ -71,6 +71,9 @@ class MS_API Dispatcher { Status UnregisterServableCommon(const std::string &worker_address); Status AddServableCommon(const WorkerSpec &worker_spec, CreateNotifyWorkerFunc func); Status RemoveServableCommon(const WorkerSpec &worker_spec); + + Status DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply, + PredictOnFinish on_finish); }; } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_process.cc b/mindspore_serving/ccsrc/master/grpc/grpc_process.cc index 1022900..ab0f85c 100644 --- a/mindspore_serving/ccsrc/master/grpc/grpc_process.cc +++ b/mindspore_serving/ccsrc/master/grpc/grpc_process.cc @@ -36,53 +36,9 @@ std::string GetProtorWorkerSpecRepr(const proto::WorkerSpec &worker_spec) { } } // namespace -Status MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, - DispatchCallback callback) { - MSI_EXCEPTION_IF_NULL(request); - MSI_EXCEPTION_IF_NULL(reply); - Status status(FAILED); - auto on_status = [request, reply](Status status) { - if (status != SUCCESS) { - (*reply->mutable_servable_spec()) = request->servable_spec(); - reply->clear_error_msg(); - auto proto_error_msg = reply->add_error_msg(); - proto_error_msg->set_error_code(FAILED); - std::string error_msg = status.StatusMessage(); - if (error_msg.empty()) { - proto_error_msg->set_error_msg("Predict failed"); - } else { - proto_error_msg->set_error_msg(error_msg); - } - } - }; - DispatchCallback callback_with_status_handle = [callback, on_status](Status status) { - on_status(status); - callback(status); - }; - try { - MSI_TIME_STAMP_START(Predict) - status = dispatcher_->DispatchAsync(*request, reply, callback_with_status_handle); - MSI_TIME_STAMP_END(Predict) - } catch (const std::bad_alloc &ex) { - MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; - std::cout << "Serving Error: malloc memory failed" << std::endl; - } catch (const std::runtime_error &ex) { - MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what(); - std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl; - } catch (const std::exception &ex) { - MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what(); - std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl; - } catch (...) { - MSI_LOG(ERROR) << "Serving Error: exception occurred"; - std::cout << "Serving Error: exception occurred"; - } - MSI_LOG(INFO) << "Finish call service Eval"; - - if (status != SUCCESS) { - on_status(status); - return status; - } - return SUCCESS; +void MSServiceImpl::PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, + PredictOnFinish on_finish) { + dispatcher_->DispatchAsync(*request, reply, on_finish); } grpc::Status MSMasterImpl::Register(grpc::ServerContext *context, const proto::RegisterRequest *request, diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_process.h b/mindspore_serving/ccsrc/master/grpc/grpc_process.h index a38c5be..f4c8d9c 100644 --- a/mindspore_serving/ccsrc/master/grpc/grpc_process.h +++ b/mindspore_serving/ccsrc/master/grpc/grpc_process.h @@ -41,7 +41,7 @@ class MSServiceImpl { explicit MSServiceImpl(std::shared_ptr dispatcher) : dispatcher_(dispatcher) {} ~MSServiceImpl() = default; - Status PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, DispatchCallback callback); + void PredictAsync(const proto::PredictRequest *request, proto::PredictReply *reply, PredictOnFinish on_finish); private: std::shared_ptr dispatcher_; diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_server.h b/mindspore_serving/ccsrc/master/grpc/grpc_server.h index 43d372d..d49cbb1 100644 --- a/mindspore_serving/ccsrc/master/grpc/grpc_server.h +++ b/mindspore_serving/ccsrc/master/grpc/grpc_server.h @@ -94,11 +94,8 @@ class MasterPredictContext : public MasterServiceContext { void HandleRequest() override { EnqueueRequest(service_impl_, async_service_, cq_); state_ = STATE::FINISH; - DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); }; - Status status = service_impl_->PredictAsync(&request_, &response_, callback); - if (!status.IsSuccess()) { - responder_.Finish(response_, grpc::Status::OK, this); - } + PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); }; + service_impl_->PredictAsync(&request_, &response_, on_finish); } bool JudgeFinish() override { return state_ == STATE::FINISH; } diff --git a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h index 5ccb0c3..30ec379 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h +++ b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h @@ -33,7 +33,7 @@ class MS_API BaseNotifyWorker { virtual ~BaseNotifyWorker() = default; virtual Status Exit() = 0; virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) = 0; + PredictOnFinish on_finish) = 0; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc index b60a86d..fe04956 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc +++ b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc @@ -20,19 +20,20 @@ #include #include "common/exit_handle.h" #include "common/grpc_server.h" +#include "common/proto_tensor.h" namespace mindspore { namespace serving { -GrpcNotfiyWorker::GrpcNotfiyWorker(const std::string &worker_address) { +GrpcNotifyWorker::GrpcNotifyWorker(const std::string &worker_address) { worker_address_ = worker_address; std::shared_ptr channel = GrpcServer::CreateChannel(worker_address); stub_ = proto::MSWorker::NewStub(channel); } -GrpcNotfiyWorker::~GrpcNotfiyWorker() = default; +GrpcNotifyWorker::~GrpcNotifyWorker() = default; -Status GrpcNotfiyWorker::Exit() { +Status GrpcNotifyWorker::Exit() { if (stub_) { proto::ExitRequest request; request.set_address(worker_address_); @@ -47,8 +48,8 @@ Status GrpcNotfiyWorker::Exit() { return SUCCESS; } -Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) { +Status GrpcNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, + PredictOnFinish on_finish) { if (!stub_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Predict failed, worker gRPC has not been inited or has already exited, worker address " @@ -58,6 +59,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro client_ = std::make_unique(); client_->Start(); } + AsyncPredictCallback callback = [reply, on_finish](Status status) { + GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); + on_finish(); + }; client_->PredictAsync(request, reply, stub_.get(), callback); return SUCCESS; } diff --git a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h index c5200ad..083076b 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h +++ b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.h @@ -27,15 +27,15 @@ namespace mindspore { namespace serving { -class MS_API GrpcNotfiyWorker : public BaseNotifyWorker { +class MS_API GrpcNotifyWorker : public BaseNotifyWorker { public: - explicit GrpcNotfiyWorker(const std::string &worker_address); - ~GrpcNotfiyWorker() override; + explicit GrpcNotifyWorker(const std::string &worker_address); + ~GrpcNotifyWorker() override; Status Exit() override; Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) override; + PredictOnFinish on_finish) override; private: std::string worker_address_; diff --git a/mindspore_serving/ccsrc/master/notify_worker/local_notify.cc b/mindspore_serving/ccsrc/master/notify_worker/local_notify.cc index 4713441..6150a66 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/local_notify.cc +++ b/mindspore_serving/ccsrc/master/notify_worker/local_notify.cc @@ -26,8 +26,8 @@ Status LocalNotifyWorker::Exit() { } Status LocalNotifyWorker::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) { - return Worker::GetInstance().RunAsync(request, reply, callback); + PredictOnFinish on_finish) { + return Worker::GetInstance().RunAsync(request, reply, on_finish); } } // namespace serving diff --git a/mindspore_serving/ccsrc/master/notify_worker/local_notify.h b/mindspore_serving/ccsrc/master/notify_worker/local_notify.h index 8fb7de2..e872dd3 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/local_notify.h +++ b/mindspore_serving/ccsrc/master/notify_worker/local_notify.h @@ -30,7 +30,7 @@ class MS_API LocalNotifyWorker : public BaseNotifyWorker { Status Exit() override; Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DispatchCallback callback) override; + PredictOnFinish on_finish) override; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/master/restful/http_process.cc b/mindspore_serving/ccsrc/master/restful/http_process.cc index ed8ed2b..c99b179 100644 --- a/mindspore_serving/ccsrc/master/restful/http_process.cc +++ b/mindspore_serving/ccsrc/master/restful/http_process.cc @@ -693,14 +693,8 @@ Status RestfulService::RunRestful(const std::shared_ptr &restful } MSI_TIME_STAMP_START(Predict) - status = dispatcher_->Dispatch(request, &reply); + dispatcher_->Dispatch(request, &reply); MSI_TIME_STAMP_END(Predict) - if (status != SUCCESS) { - std::string error_msg = status.StatusMessage(); - std::string msg = "Predict failed, " + error_msg; - status = msg; - return status; - } MSI_TIME_STAMP_START(CreateReplyJson) status = ParseReply(reply, out_json); @@ -1037,11 +1031,6 @@ Status RestfulService::CheckReply(const ProtoTensor &pb_tensor) { return status; } -void RestfulService::ParseErrorMsg(const proto::ErrorMsg &error, json *const js) { - std::string str = error.error_msg(); - *js = str; -} - // 5.Parse reply Status RestfulService::ParseReply(const PredictReply &reply, json *const out_json) { Status status(SUCCESS); @@ -1059,56 +1048,42 @@ Status RestfulService::ParseReply(const PredictReply &reply, json *const out_jso Status RestfulService::ParseInstancesReply(const PredictReply &reply, json *const out_json) { Status status(SUCCESS); auto error_size = reply.error_msg_size(); - if (error_size != 0 && error_size != 1 && error_size != instances_nums_) { + auto reply_size = reply.instances().size(); + if (error_size == 1 && reply_size == 0) { + (*out_json)[kErrorMsg] = reply.error_msg()[0].error_msg(); + return SUCCESS; + } + if (error_size != 0 && error_size != instances_nums_) { return INFER_STATUS_LOG_ERROR(FAILED) << "reply error size:" << error_size << " is not 0,1 or instances size"; } + if (reply_size != instances_nums_) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "reply size:" << reply_size << " is not matched request size:" << instances_nums_; + } (*out_json)[kInstancesReply] = json(); json &instances_json = (*out_json)[kInstancesReply]; - int32_t reply_num = reply.instances().size(); - if (reply_num == 0) { - reply_num = error_size; - } - if (error_size == 0 && reply_num != instances_nums_) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "reply size:" << reply_num << " is not matched request size:" << instances_nums_; - } - - for (int32_t i = 0; i < reply_num; i++) { - bool success_flag = true; - if (i < error_size) { - auto &cur_error = reply.error_msg().at(i); - success_flag = (cur_error.error_code() == 0); + for (int32_t i = 0; i < instances_nums_; i++) { + instances_json.push_back(json()); + auto &instance = instances_json.back(); + if (error_size != 0 && reply.error_msg()[i].error_code() != 0) { + instance[kErrorMsg] = reply.error_msg(i).error_msg(); + continue; + } + auto &cur_instance = reply.instances(i); + auto &items = cur_instance.items(); + if (items.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty"; } - if (success_flag) { - if (i >= reply.instances_size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "index:" << i << " is more than reply instances size:" << reply.instances_size(); - } - auto &cur_instance = reply.instances(i); - auto &items = cur_instance.items(); - if (items.empty()) { - return INFER_STATUS_LOG_ERROR(FAILED) << "reply instance items is empty"; - } - instances_json.push_back(json()); - auto &instance = instances_json.back(); - - for (auto &item : items) { - instance[item.first] = json(); - auto &value_json = instance[item.first]; - status = ParseReplyDetail(item.second, &value_json); - if (status != SUCCESS) { - return status; - } + for (auto &item : items) { + instance[item.first] = json(); + auto &value_json = instance[item.first]; + status = ParseReplyDetail(item.second, &value_json); + if (status != SUCCESS) { + return status; } - } else { - instances_json.push_back(json()); - auto &obj = instances_json.back(); - obj[kErrorMsg] = json(); - auto &js = obj[kErrorMsg]; - ParseErrorMsg(reply.error_msg(i), &js); } } return status; diff --git a/mindspore_serving/ccsrc/master/restful/http_process.h b/mindspore_serving/ccsrc/master/restful/http_process.h index be7e668..5e0f83a 100644 --- a/mindspore_serving/ccsrc/master/restful/http_process.h +++ b/mindspore_serving/ccsrc/master/restful/http_process.h @@ -98,7 +98,6 @@ class RestfulService { Status ParseScalarData(const ProtoTensor &pb_tensor, bool is_bytes, size_t index, json *const js); template bool IsString(); - void ParseErrorMsg(const proto::ErrorMsg &error_msg, json *const js); RequestType request_type_{kInvalidType}; InstancesType instances_type_{kInvalidWay}; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h index 5661f9c..8e14a4a 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h @@ -41,13 +41,13 @@ class MSDistributedImpl final : public MSWorkerImpl { : MSWorkerImpl(server_address), servable_(servable) {} ~MSDistributedImpl() = default; grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, - proto::AgentRegisterReply *reply) override; + proto::AgentRegisterReply *reply); grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, - proto::AgentExitReply *reply) override; + proto::AgentExitReply *reply); grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, - proto::AgentFailedReply *reply) override; + proto::AgentFailedReply *reply); grpc::Status AgentConfigAcquire(grpc::ServerContext *context, const proto::AgentConfigAcquireRequest *request, - proto::AgentConfigAcquireReply *reply) override; + proto::AgentConfigAcquireReply *reply); private: std::shared_ptr servable_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index f472a85..9cb043a 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -76,7 +76,7 @@ Status DistributedServable::PredictInner(const std::vector &input auto msg_list = std::make_shared>(rank_size); for (size_t i = 0; i < rank_size; ++i) { - DispatchCallback callback = [msg_list, i](const Status &status) { + AsyncPredictCallback callback = [msg_list, i](const Status &status) { msg_list->at(i).status = status; msg_list->at(i).promise.set_value(); }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h index ac4d5c7..6e8d967 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h @@ -33,7 +33,7 @@ class MS_API BaseNotifyAgent { virtual ~BaseNotifyAgent() = default; virtual Status Exit() = 0; virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, - DispatchCallback callback) = 0; + AsyncPredictCallback callback) = 0; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc index 13c9a46..05b928d 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc @@ -55,7 +55,7 @@ Status GrpcNotifyAgent::Exit() { } Status GrpcNotifyAgent::DispatchAsync(const proto::DistributedPredictRequest &request, - proto::DistributedPredictReply *reply, DispatchCallback callback) { + proto::DistributedPredictReply *reply, AsyncPredictCallback callback) { if (!stub_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h index 2363b5a..bff419b 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h @@ -35,7 +35,7 @@ class MS_API GrpcNotifyAgent : public BaseNotifyAgent { Status Exit() override; Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, - DispatchCallback callback) override; + AsyncPredictCallback callback) override; private: std::string agent_address_; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc index 14dd967..11b6d81 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc @@ -16,6 +16,7 @@ #include "worker/grpc/worker_process.h" #include "worker/worker.h" +#include "common/proto_tensor.h" namespace mindspore { namespace serving { @@ -26,13 +27,13 @@ grpc::Status MSWorkerImpl::Exit(grpc::ServerContext *context, const proto::ExitR return grpc::Status::OK; } -grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply, DispatchCallback callback) { +void MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, + proto::PredictReply *reply, PredictOnFinish on_finish) { Status status(FAILED); MSI_LOG(INFO) << "Begin call service Eval"; try { MSI_TIME_STAMP_START(Predict) - status = Worker::GetInstance().RunAsync(*request, reply, callback); + status = Worker::GetInstance().RunAsync(*request, reply, on_finish); MSI_TIME_STAMP_END(Predict) } catch (const std::bad_alloc &ex) { MSI_LOG(ERROR) << "Serving Error: malloc memory failed"; @@ -49,19 +50,12 @@ grpc::Status MSWorkerImpl::PredictAsync(grpc::ServerContext *context, const prot } MSI_LOG(INFO) << "Finish call service Eval"; - if (status == INVALID_INPUTS) { - auto proto_error_msg = reply->add_error_msg(); - proto_error_msg->set_error_code(status.StatusCode()); - proto_error_msg->set_error_msg(status.StatusMessage()); - return grpc::Status::OK; - } else if (status != SUCCESS) { - auto proto_error_msg = reply->add_error_msg(); - proto_error_msg->set_error_code(FAILED); - proto_error_msg->set_error_msg("Predict failed"); - return grpc::Status::OK; + if (status != SUCCESS) { + GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply); + on_finish(); } - return grpc::Status::OK; } + grpc::Status MSWorkerImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) { MSI_EXCEPTION_IF_NULL(request); diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.h b/mindspore_serving/ccsrc/worker/grpc/worker_process.h index b53e851..10e6d27 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.h @@ -35,7 +35,7 @@ namespace mindspore { namespace serving { // Service Implement -class MSWorkerImpl : public proto::MSWorker::Service { +class MSWorkerImpl { public: explicit MSWorkerImpl(const std::string server_address) { if (!watcher_) { @@ -43,11 +43,11 @@ class MSWorkerImpl : public proto::MSWorker::Service { } } - grpc::Status PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply, DispatchCallback callback); - grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; - grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; - grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; + void PredictAsync(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply, + PredictOnFinish on_finish); + grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply); + grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply); + grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply); std::shared_ptr> watcher_; }; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 23a9656..dbb3f13 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -100,11 +100,8 @@ class WorkerPredictContext : public WorkerServiceContext { void HandleRequest() override { EnqueueRequest(service_impl_, async_service_, cq_); state_ = STATE::FINISH; - DispatchCallback callback = [this](Status status) { responder_.Finish(response_, grpc::Status::OK, this); }; - grpc::Status status = service_impl_->PredictAsync(&ctx_, &request_, &response_, callback); - if (!status.ok()) { - responder_.Finish(response_, status, this); - } + PredictOnFinish on_finish = [this]() { responder_.Finish(response_, grpc::Status::OK, this); }; + service_impl_->PredictAsync(&ctx_, &request_, &response_, on_finish); } private: diff --git a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc index c959d87..4ca3504 100644 --- a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc +++ b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc @@ -141,22 +141,22 @@ std::shared_ptr MindSporeModelWrap::TransformModelContext(const std::ma MSI_LOG_ERROR << "Set model context output type failed, unknown data type " << val; } }; - std::map option_map = { - {"acl_option.insert_op_config_file_path", mindspore::ModelContext::SetInsertOpConfigPath}, - {"acl_option.input_format", mindspore::ModelContext::SetInputFormat}, - {"acl_option.input_shape", mindspore::ModelContext::SetInputShape}, - {"acl_option.output_type", set_output_type}, - {"acl_option.precision_mode", mindspore::ModelContext::SetPrecisionMode}, - {"acl_option.op_select_impl_mode", mindspore::ModelContext::SetOpSelectImplMode}, - }; auto context = std::make_shared(); for (auto &item : options) { const auto &key = item.first; const auto &value = item.second; - auto it = option_map.find(key); - if (it != option_map.end()) { - MSI_LOG_INFO << "Set context options, key: " << key << ", value: " << value; - it->second(context, value); + if (key == "acl_option.insert_op_config_file_path") { + mindspore::ModelContext::SetInsertOpConfigPath(context, value); + } else if (key == "acl_option.input_format") { + mindspore::ModelContext::SetInputFormat(context, value); + } else if (key == "acl_option.input_shape") { + mindspore::ModelContext::SetInputShape(context, value); + } else if (key == "acl_option.output_type") { + set_output_type(context, value); + } else if (key == "acl_option.precision_mode") { + mindspore::ModelContext::SetPrecisionMode(context, value); + } else if (key == "acl_option.op_select_impl_mode") { + mindspore::ModelContext::SetOpSelectImplMode(context, value); } } return context; diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index faf877e..22cd7ba 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -60,7 +60,7 @@ Status Worker::RemoveWorker(const ServableWorkerContext &work) { return notify_master_->RemoveWorker(work.worker_spec); } -Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback) { +Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish) { std::shared_lock lock(worker_shared_lock_); if (!servable_started_) { return INFER_STATUS_LOG_ERROR(FAILED) << "RunAsync worker for inference failed, worker has not been started"; @@ -81,16 +81,9 @@ Status Worker::RunAsync(const proto::PredictRequest &request, proto::PredictRepl if (worker.worker_service == nullptr) { return INFER_STATUS_LOG_ERROR(FAILED) << "Cannot find servable match " << request_spec.Repr(); } - WorkCallBack on_process_done = [request, reply, callback](const std::vector &instances) { - auto status = GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply); - if (status != SUCCESS) { - MSI_LOG_ERROR << "transfer result to reply failed"; - reply->clear_error_msg(); - auto proto_error = reply->add_error_msg(); - proto_error->set_error_code(status.StatusCode()); - proto_error->set_error_msg(status.StatusMessage()); - } - callback(SUCCESS); + WorkCallBack on_process_done = [request, reply, on_finish](const std::vector &instances) { + GrpcTensorHelper::CreateReplyFromInstances(request, instances, reply); + on_finish(); }; return worker.worker_service->Work(request_spec, instances_data, on_process_done); } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index d74c417..fa1a8f8 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -53,7 +53,7 @@ class MS_API Worker { static Worker &GetInstance(); void Clear(); - Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DispatchCallback callback); + Status RunAsync(const proto::PredictRequest &request, proto::PredictReply *reply, PredictOnFinish on_finish); Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); Status StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, diff --git a/tests/ut/cpp/common/test_servable_common.h b/tests/ut/cpp/common/test_servable_common.h index 5f565ac..3f3f876 100644 --- a/tests/ut/cpp/common/test_servable_common.h +++ b/tests/ut/cpp/common/test_servable_common.h @@ -376,11 +376,8 @@ class TestMasterWorkerClient : public TestMasterWorker { grpc::ServerContext context; auto promise = std::make_shared>(); auto future = promise->get_future(); - DispatchCallback callback = [promise](Status status) { promise->set_value(); }; - auto status = impl.PredictAsync(&request, reply, callback); - if (!status.IsSuccess()) { - return grpc::Status::OK; - } + PredictOnFinish callback = [promise]() { promise->set_value(); }; + impl.PredictAsync(&request, reply, callback); future.get(); return grpc::Status::OK; } diff --git a/tests/ut/python/tests/test_mater_worker_client.py b/tests/ut/python/tests/test_mater_worker_client.py index 7154d31..2d33868 100644 --- a/tests/ut/python/tests/test_mater_worker_client.py +++ b/tests/ut/python/tests/test_mater_worker_client.py @@ -173,7 +173,7 @@ def test_grpc_start_restful_server_twice_failed(): @serving_test -def test_grpc_alone_repeat_master_and_woker_port_failed(): +def test_grpc_alone_repeat_master_and_worker_port_failed(): base = ServingTestBase() base.init_servable(1, "add_servable_config.py") master.start_master_server(master_port=7600) @@ -731,3 +731,370 @@ def add_common(x1, x2): assert result[0]["y"] == 0 assert "Preprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) assert result[0]["y"] == 0 + + +@serving_test +def test_servable_worker_with_master_preprocess_runtime_error(): + # fail returned from Preprocess + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +index = 0 +def preprocess(instances): + count = len(instances) + global index + for i in range(count): + ret = index + index += 1 + if ret == 0: + raise RuntimeError("runtime error") + yield ret + +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + x3 = register.call_preprocess_pipeline(preprocess, x1) + y = register.call_servable(x1, x2) + return x3 +""" + base.init_servable_with_servable_config(1, servable_content) + worker.start_servable_in_master(base.servable_dir, base.servable_name) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Preprocess Failed" in result[0]["error"] + assert result[1]["y"] == 1 + assert result[2]["y"] == 2 + + +@serving_test +def test_servable_worker_alone_preprocess_runtime_error(): + # fail returned from Preprocess + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +index = 0 +def preprocess(instances): + count = len(instances) + global index + for i in range(count): + ret = index + index += 1 + if ret == 0: + raise RuntimeError("runtime error") + yield ret + +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + x3 = register.call_preprocess_pipeline(preprocess, x1) + y = register.call_servable(x1, x2) + return x3 +""" + base.init_servable_with_servable_config(1, servable_content) + master.start_master_server("0.0.0.0", 6100) + worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Preprocess Failed" in result[0]["error"] + assert result[1]["y"] == 1 + assert result[2]["y"] == 2 + + +@serving_test +def test_servable_worker_with_master_predict_check_failed(): + # fail returned from Predict + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + y = register.call_servable(x1, x2) + return y +""" + base.init_servable_with_servable_config(1, servable_content) + worker.start_servable_in_master(base.servable_dir, base.servable_name) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + if i == 0: + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + else: + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"] + assert "y" in result[1] + assert "y" in result[2] + + +@serving_test +def test_servable_worker_alone_predict_check_failed(): + # fail returned from Predict + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + y = register.call_servable(x1, x2) + return y +""" + base.init_servable_with_servable_config(1, servable_content) + master.start_master_server("0.0.0.0", 6100) + worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + if i == 0: + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + else: + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Given model input 0 size 8 not match the size 16 defined in model" in result[0]["error"] + assert "y" in result[1] + assert "y" in result[2] + + +@serving_test +def test_servable_worker_with_master_postprocess_runtime_error(): + # fail returned from Preprocess + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +index = 0 +def postprocess(y): + global index + ret = index + index += 1 + if ret == 0: + raise RuntimeError("runtime error") + return ret + +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + y = register.call_servable(x1, x2) + y = register.call_postprocess(postprocess, y) + return y +""" + base.init_servable_with_servable_config(1, servable_content) + worker.start_servable_in_master(base.servable_dir, base.servable_name) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Postprocess Failed" in result[0]["error"] + assert result[1]["y"] == 1 + assert result[2]["y"] == 2 + + +@serving_test +def test_servable_worker_alone_postprocess_runtime_error(): + # fail returned from Preprocess + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += r""" +index = 0 +def postprocess(y): + global index + ret = index + index += 1 + if ret == 0: + raise RuntimeError("runtime error") + return ret + +@register.register_method(output_names=["y"]) +def add_common(x1, x2): + y = register.call_servable(x1, x2) + y = register.call_postprocess(postprocess, y) + return y +""" + base.init_servable_with_servable_config(1, servable_content) + master.start_master_server("0.0.0.0", 6100) + worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1, 2.2], [3.3, 4.4]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5, 6.6], [7.7, 8.8]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x1": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Postprocess Failed" in result[0]["error"] + assert result[1]["y"] == 1 + assert result[2]["y"] == 2 + + +@serving_test +def test_servable_worker_with_master_input_param_less(): + # fail returned from Worker::RunAsync + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += servable_config_method_add_common + base.init_servable_with_servable_config(1, servable_content) + worker.start_servable_in_master(base.servable_dir, base.servable_name) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x3": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Cannot find input x1 in instance input" in result["error"] + + +@serving_test +def test_servable_worker_alone_input_param_less(): + # fail returned from Worker::RunAsync + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += servable_config_method_add_common + base.init_servable_with_servable_config(1, servable_content) + master.start_master_server("0.0.0.0", 6100) + worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x3": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name, "add_common") + result = client.infer(instances) + print(result) + assert "Cannot find input x1 in instance input" in result["error"] + + +@serving_test +def test_servable_worker_with_master_servable_not_available(): + # fail returned from Worker::RunAsync + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += servable_config_method_add_common + base.init_servable_with_servable_config(1, servable_content) + worker.start_servable_in_master(base.servable_dir, base.servable_name) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x3": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name + "error", "add_common") + result = client.infer(instances) + print(result) + assert "servable is not available" in result["error"] + + +@serving_test +def test_servable_worker_alone_servable_not_available(): + # fail returned from Worker::RunAsync + base = ServingTestBase() + servable_content = servable_config_import + servable_content += servable_config_declare_servable + servable_content += servable_config_method_add_common + base.init_servable_with_servable_config(1, servable_content) + master.start_master_server("0.0.0.0", 6100) + worker.start_servable(base.servable_dir, base.servable_name, worker_port=6200, master_port=6100) + master.start_grpc_server("0.0.0.0", 5500) + # Client + instance_count = 3 + + instances = [] + y_data_list = [] + for i in range(instance_count): + x1 = np.asarray([[1.1], [3.3]]).astype(np.float32) * (i + 1) + x2 = np.asarray([[5.5], [7.7]]).astype(np.float32) * (i + 1) + y_data_list.append(x1 + x2) + instances.append({"x3": x1, "x2": x2}) + + client = create_client("localhost", 5500, base.servable_name + "error", "add_common") + result = client.infer(instances) + print(result) + assert "servable is not available" in result["error"] diff --git a/tests/ut/stub/cxx_api/context.cc b/tests/ut/stub/cxx_api/context.cc index a9ea405..d967963 100644 --- a/tests/ut/stub/cxx_api/context.cc +++ b/tests/ut/stub/cxx_api/context.cc @@ -14,6 +14,9 @@ * limitations under the License. */ #include "include/api/context.h" +#include +#include +#include #include "utils/log_adapter.h" constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target"; @@ -28,18 +31,28 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode"; constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; namespace mindspore { -template -static T GetValue(const std::shared_ptr &context, const std::string &key) { - auto iter = context->params.find(key); - if (iter == context->params.end()) { - return T(); +struct Context::Data { + std::map params; +}; + +Context::Context() : data(std::make_shared()) {} + +template >> +static const U &GetValue(const std::shared_ptr &context, const std::string &key) { + static U empty_result; + if (context == nullptr || context->data == nullptr) { + return empty_result; + } + auto iter = context->data->params.find(key); + if (iter == context->data->params.end()) { + return empty_result; } const std::any &value = iter->second; - if (value.type() != typeid(T)) { - return T(); + if (value.type() != typeid(U)) { + return empty_result; } - return std::any_cast(value); + return std::any_cast(value); } std::shared_ptr GlobalContext::GetGlobalContext() { @@ -47,22 +60,31 @@ std::shared_ptr GlobalContext::GetGlobalContext() { return g_context; } -void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { +void GlobalContext::SetGlobalDeviceTarget(const std::vector &device_target) { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - global_context->params[kGlobalContextDeviceTarget] = device_target; + if (global_context->data == nullptr) { + global_context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(global_context->data); + } + global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target); } -std::string GlobalContext::GetGlobalDeviceTarget() { +std::vector GlobalContext::GetGlobalDeviceTargetChar() { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - return GetValue(global_context, kGlobalContextDeviceTarget); + const std::string &ref = GetValue(global_context, kGlobalContextDeviceTarget); + return StringToChar(ref); } void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) { auto global_context = GetGlobalContext(); MS_EXCEPTION_IF_NULL(global_context); - global_context->params[kGlobalContextDeviceID] = device_id; + if (global_context->data == nullptr) { + global_context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(global_context->data); + } + global_context->data->params[kGlobalContextDeviceID] = device_id; } uint32_t GlobalContext::GetGlobalDeviceID() { @@ -71,39 +93,58 @@ uint32_t GlobalContext::GetGlobalDeviceID() { return GetValue(global_context, kGlobalContextDeviceID); } -void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::string &cfg_path) { +void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInsertOpCfgPath] = cfg_path; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path); } -std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr &context) { +std::vector ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInsertOpCfgPath); + const std::string &ref = GetValue(context, kModelOptionInsertOpCfgPath); + return StringToChar(ref); } -void ModelContext::SetInputFormat(const std::shared_ptr &context, const std::string &format) { +void ModelContext::SetInputFormat(const std::shared_ptr &context, const std::vector &format) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInputFormat] = format; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInputFormat] = CharToString(format); } -std::string ModelContext::GetInputFormat(const std::shared_ptr &context) { +std::vector ModelContext::GetInputFormatChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInputFormat); + const std::string &ref = GetValue(context, kModelOptionInputFormat); + return StringToChar(ref); } -void ModelContext::SetInputShape(const std::shared_ptr &context, const std::string &shape) { +void ModelContext::SetInputShape(const std::shared_ptr &context, const std::vector &shape) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionInputShape] = shape; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionInputShape] = CharToString(shape); } -std::string ModelContext::GetInputShape(const std::shared_ptr &context) { +std::vector ModelContext::GetInputShapeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionInputShape); + const std::string &ref = GetValue(context, kModelOptionInputShape); + return StringToChar(ref); } void ModelContext::SetOutputType(const std::shared_ptr &context, enum DataType output_type) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionOutputType] = output_type; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionOutputType] = output_type; } enum DataType ModelContext::GetOutputType(const std::shared_ptr &context) { @@ -111,24 +152,34 @@ enum DataType ModelContext::GetOutputType(const std::shared_ptr &contex return GetValue(context, kModelOptionOutputType); } -void ModelContext::SetPrecisionMode(const std::shared_ptr &context, const std::string &precision_mode) { +void ModelContext::SetPrecisionMode(const std::shared_ptr &context, const std::vector &precision_mode) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionPrecisionMode] = precision_mode; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode); } -std::string ModelContext::GetPrecisionMode(const std::shared_ptr &context) { +std::vector ModelContext::GetPrecisionModeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionPrecisionMode); + const std::string &ref = GetValue(context, kModelOptionPrecisionMode); + return StringToChar(ref); } void ModelContext::SetOpSelectImplMode(const std::shared_ptr &context, - const std::string &op_select_impl_mode) { + const std::vector &op_select_impl_mode) { MS_EXCEPTION_IF_NULL(context); - context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode; + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode); } -std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr &context) { +std::vector ModelContext::GetOpSelectImplModeChar(const std::shared_ptr &context) { MS_EXCEPTION_IF_NULL(context); - return GetValue(context, kModelOptionOpSelectImplMode); + const std::string &ref = GetValue(context, kModelOptionOpSelectImplMode); + return StringToChar(ref); } } // namespace mindspore diff --git a/tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.cc b/tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.cc similarity index 65% rename from tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.cc rename to tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.cc index dd84154..8942c1f 100644 --- a/tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.cc +++ b/tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -13,38 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "cxx_api/graph/ms/ms_graph_impl.h" +#include "cxx_api/graph/ascend/ascend_graph_impl.h" #include #include "include/api/context.h" #include "cxx_api/factory.h" #include "stub/graph_impl_stub.h" namespace mindspore { -API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl); +API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); -std::shared_ptr MsGraphImpl::graph_imp_stub_ = std::make_shared(); +std::shared_ptr AscendGraphImpl::graph_imp_stub_ = std::make_shared(); -MsGraphImpl::MsGraphImpl() {} +AscendGraphImpl::AscendGraphImpl() {} -MsGraphImpl::~MsGraphImpl() {} +AscendGraphImpl::~AscendGraphImpl() {} -std::vector MsGraphImpl::GetInputs() { +std::vector AscendGraphImpl::GetInputs() { if (!graph_imp_stub_) { return {}; } return graph_imp_stub_->GetInputs(); } -std::vector MsGraphImpl::GetOutputs() { +std::vector AscendGraphImpl::GetOutputs() { if (!graph_imp_stub_) { return {}; } return graph_imp_stub_->GetOutputs(); } -Status MsGraphImpl::Load() { return kSuccess; } +Status AscendGraphImpl::Load() { return kSuccess; } -Status MsGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { +Status AscendGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { if (!graph_imp_stub_) { return kMCFailed; } diff --git a/tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.h b/tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h similarity index 80% rename from tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.h rename to tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h index 3a33f81..23aba62 100644 --- a/tests/ut/stub/cxx_api/graph/ms/ms_graph_impl.h +++ b/tests/ut/stub/cxx_api/graph/ascend/ascend_graph_impl.h @@ -13,25 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H +#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H +#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H #include #include #include #include #include #include -#include #include "include/api/status.h" #include "include/api/graph.h" #include "cxx_api/graph/graph_impl.h" #include "cxx_api/model/model_impl.h" namespace mindspore { -class MsGraphImpl : public GraphCell::GraphImpl { + +class AscendGraphImpl : public GraphCell::GraphImpl { public: - MsGraphImpl(); - ~MsGraphImpl() override; + AscendGraphImpl(); + ~AscendGraphImpl() override; Status Run(const std::vector &inputs, std::vector *outputs) override; Status Load() override; @@ -43,4 +43,4 @@ class MsGraphImpl : public GraphCell::GraphImpl { }; } // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H +#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H diff --git a/tests/ut/stub/cxx_api/model/model.cc b/tests/ut/stub/cxx_api/model/model.cc index 900347f..f50fd0c 100644 --- a/tests/ut/stub/cxx_api/model/model.cc +++ b/tests/ut/stub/cxx_api/model/model.cc @@ -17,8 +17,16 @@ #include "include/api/context.h" #include "cxx_api/model/model_impl.h" #include "cxx_api/factory.h" +#include "utils/utils.h" namespace mindspore { +namespace { +const std::map> kSupportedModelMap = { + {kDeviceTypeAscend310, {kOM, kMindIR}}, + {kDeviceTypeAscend910, {kMindIR}}, + {kDeviceTypeGPU, {kMindIR}}, +}; +} Status Model::Build() { MS_EXCEPTION_IF_NULL(impl_); return impl_->Build(); @@ -60,8 +68,22 @@ Model::Model(const std::vector &network, const std::shared_ptr Model::~Model() {} -bool Model::CheckModelSupport(const std::string &device_type, ModelType) { - return Factory::Instance().CheckModelSupport(device_type); -} +bool Model::CheckModelSupport(const std::vector &device_type, ModelType model_type) { + std::string device_type_str = CharToString(device_type); + if (!Factory::Instance().CheckModelSupport(device_type_str)) { + return false; + } + + auto first_iter = kSupportedModelMap.find(device_type_str); + if (first_iter == kSupportedModelMap.end()) { + return false; + } + + auto secend_iter = first_iter->second.find(model_type); + if (secend_iter == first_iter->second.end()) { + return false; + } + return true; +} } // namespace mindspore diff --git a/tests/ut/stub/cxx_api/model/model_impl.h b/tests/ut/stub/cxx_api/model/model_impl.h index 4e66c0f..5efd287 100644 --- a/tests/ut/stub/cxx_api/model/model_impl.h +++ b/tests/ut/stub/cxx_api/model/model_impl.h @@ -25,6 +25,7 @@ #include "include/api/model.h" #include "include/api/graph.h" #include "cxx_api/graph/graph_data.h" +#include "utils/utils.h" namespace mindspore { class ModelImpl { @@ -63,7 +64,9 @@ class ModelImpl { friend class Model; void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } void SetContext(const std::shared_ptr &model_context) { - model_context_ = std::make_shared(*model_context); + if (model_context != nullptr) { + model_context_ = std::make_shared(*model_context); + } } }; } // namespace mindspore diff --git a/tests/ut/stub/cxx_api/model/ms/ms_model.cc b/tests/ut/stub/cxx_api/model/ms/ms_model.cc index c13f923..722b579 100644 --- a/tests/ut/stub/cxx_api/model/ms/ms_model.cc +++ b/tests/ut/stub/cxx_api/model/ms/ms_model.cc @@ -66,6 +66,11 @@ Status MsModel::Build() { MS_LOG(INFO) << "Start build model."; MS_EXCEPTION_IF_NULL(graph_); + if (graph_cell_ != nullptr) { + MS_LOG(INFO) << "This model has been built, skip."; + return kSuccess; + } + auto func_graph = ModelImpl::GetFuncGraph(); MS_EXCEPTION_IF_NULL(func_graph); diff --git a/tests/ut/stub/cxx_api/serialization.cc b/tests/ut/stub/cxx_api/serialization.cc index 7353387..f964eca 100644 --- a/tests/ut/stub/cxx_api/serialization.cc +++ b/tests/ut/stub/cxx_api/serialization.cc @@ -16,7 +16,7 @@ #include "include/api/serialization.h" #include #include "cxx_api/graph/graph_data.h" -#include "utils/utils.h" +#include "utils/log_adapter.h" namespace mindspore { static Buffer ReadFile(const std::string &file) { @@ -77,13 +77,17 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; } -Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { - Buffer data = ReadFile(file); +Graph Serialization::LoadModel(const std::vector &file, ModelType model_type) { + std::string file_path = CharToString(file); + Buffer data = ReadFile(file_path); if (data.Data() == nullptr) { - MS_LOG(EXCEPTION) << "Read file " << file << " failed."; + MS_LOG(EXCEPTION) << "Read file " << file_path << " failed."; } if (model_type == kMindIR) { auto anf_graph = std::make_shared(); + if (anf_graph == nullptr) { + MS_LOG(EXCEPTION) << "Load model failed."; + } return Graph(std::make_shared(anf_graph, kMindIR)); } else if (model_type == kOM) { return Graph(std::make_shared(data, kOM)); diff --git a/tests/ut/stub/cxx_api/status.cc b/tests/ut/stub/cxx_api/status.cc new file mode 100644 index 0000000..43a386e --- /dev/null +++ b/tests/ut/stub/cxx_api/status.cc @@ -0,0 +1,207 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/api/status.h" +#ifndef ENABLE_ANDROID +#include +#endif +#include +#include + +namespace mindspore { +struct Status::Data { + enum StatusCode status_code = kSuccess; + std::string status_msg; + int line_of_code = -1; + std::string file_name; + std::string err_description; +}; + +Status::Status() : data_(std::make_shared()) {} + +Status::Status(enum StatusCode status_code, const std::vector &status_msg) : data_(std::make_shared()) { + if (data_ == nullptr) { + return; + } + + data_->status_msg = CharToString(status_msg); + data_->status_code = status_code; +} + +Status::Status(enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra) + : data_(std::make_shared()) { + if (data_ == nullptr) { + return; + } + data_->status_code = code; + data_->line_of_code = line_of_code; + if (file_name != nullptr) { + data_->file_name = file_name; + } + data_->err_description = CharToString(extra); + + std::ostringstream ss; +#ifndef ENABLE_ANDROID + ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(code) << ". "; + if (!data_->err_description.empty()) { + ss << data_->err_description; + } + ss << "\n"; +#endif + + ss << "Line of code : " << line_of_code << "\n"; + if (file_name != nullptr) { + ss << "File : " << file_name << "\n"; + } + data_->status_msg = ss.str(); +} + +enum StatusCode Status::StatusCode() const { + if (data_ == nullptr) { + return kSuccess; + } + return data_->status_code; +} + +std::vector Status::ToCString() const { + if (data_ == nullptr) { + return std::vector(); + } + return StringToChar(data_->status_msg); +} + +int Status::GetLineOfCode() const { + if (data_ == nullptr) { + return -1; + } + return data_->line_of_code; +} + +std::vector Status::GetErrDescriptionChar() const { + if (data_ == nullptr) { + return std::vector(); + } + return StringToChar(data_->status_msg); +} + +std::vector Status::CodeAsCString(enum StatusCode c) { + static std::map info_map = {{kSuccess, "No error occurs."}, + // Core + {kCoreFailed, "Common error code."}, + // MD + {kMDOutOfMemory, "Out of memory"}, + {kMDShapeMisMatch, "Shape is incorrect"}, + {kMDInterrupted, "Interrupted system call"}, + {kMDNoSpace, "No space left on device"}, + {kMDPyFuncException, "Exception thrown from PyFunc"}, + {kMDDuplicateKey, "Duplicate key"}, + {kMDPythonInterpreterFailure, ""}, + {kMDTDTPushFailure, "Unexpected error"}, + {kMDFileNotExist, "Unexpected error"}, + {kMDProfilingError, "Error encountered while profiling"}, + {kMDBoundingBoxOutOfBounds, "Unexpected error"}, + {kMDBoundingBoxInvalidShape, "Unexpected error"}, + {kMDSyntaxError, "Syntax error"}, + {kMDTimeOut, "Unexpected error"}, + {kMDBuddySpaceFull, "BuddySpace full"}, + {kMDNetWorkError, "Network error"}, + {kMDNotImplementedYet, "Unexpected error"}, + {kMDUnexpectedError, "Unexpected error"}, + // ME + {kMEFailed, "Common error code."}, + {kMEInvalidInput, "Invalid input."}, + // MC + {kMCFailed, "Common error code."}, + {kMCDeviceError, "Device error."}, + {kMCInvalidInput, "Invalid input."}, + {kMCInvalidArgs, "Invalid arguments."}, + // Lite + {kLiteError, "Common error code."}, + {kLiteNullptr, "NULL pointer returned."}, + {kLiteParamInvalid, "Invalid parameter."}, + {kLiteNoChange, "No change."}, + {kLiteSuccessExit, "No error but exit."}, + {kLiteMemoryFailed, "Fail to create memory."}, + {kLiteNotSupport, "Fail to support."}, + {kLiteThreadPoolError, "Thread pool error."}, + {kLiteOutOfTensorRange, "Failed to check range."}, + {kLiteInputTensorError, "Failed to check input tensor."}, + {kLiteReentrantError, "Exist executor running."}, + {kLiteGraphFileError, "Failed to verify graph file."}, + {kLiteNotFindOp, "Failed to find operator."}, + {kLiteInvalidOpName, "Invalid operator name."}, + {kLiteInvalidOpAttr, "Invalid operator attr."}, + {kLiteOpExecuteFailure, "Failed to execution operator."}, + {kLiteFormatError, "Failed to checking tensor format."}, + {kLiteInferError, "Failed to infer shape."}, + {kLiteInferInvalid, "Invalid infer shape before runtime."}, + {kLiteInputParamInvalid, "Invalid input param by user."}}; + auto iter = info_map.find(c); + return StringToChar(iter == info_map.end() ? "Unknown error" : iter->second); +} + +std::ostream &operator<<(std::ostream &os, const Status &s) { + os << s.ToString(); + return os; +} + +std::vector Status::SetErrDescription(const std::vector &err_description) { + if (data_ == nullptr) { + return std::vector(); + } + data_->err_description = CharToString(err_description); + std::ostringstream ss; +#ifndef ENABLE_ANDROID + ss << "Thread ID " << std::this_thread::get_id() << " " << CodeAsString(data_->status_code) << ". "; + if (!data_->err_description.empty()) { + ss << data_->err_description; + } + ss << "\n"; +#endif + + if (data_->line_of_code > 0 && !data_->file_name.empty()) { + ss << "Line of code : " << data_->line_of_code << "\n"; + ss << "File : " << data_->file_name << "\n"; + } + data_->status_msg = ss.str(); + return StringToChar(data_->status_msg); +} + +bool Status::operator==(const Status &other) const { + if (data_ == nullptr && other.data_ == nullptr) { + return true; + } + + if (data_ == nullptr || other.data_ == nullptr) { + return false; + } + + return data_->status_code == other.data_->status_code; +} + +bool Status::operator==(enum StatusCode other_code) const { return StatusCode() == other_code; } +bool Status::operator!=(const Status &other) const { return !operator==(other); } +bool Status::operator!=(enum StatusCode other_code) const { return !operator==(other_code); } + +Status::operator bool() const { return (StatusCode() == kSuccess); } +Status::operator int() const { return static_cast(StatusCode()); } + +Status Status::OK() { return StatusCode::kSuccess; } +bool Status::IsOk() const { return (StatusCode() == StatusCode::kSuccess); } +bool Status::IsError() const { return !IsOk(); } +} // namespace mindspore diff --git a/tests/ut/stub/cxx_api/types.cc b/tests/ut/stub/cxx_api/types.cc index 457b5ef..13a064b 100644 --- a/tests/ut/stub/cxx_api/types.cc +++ b/tests/ut/stub/cxx_api/types.cc @@ -133,10 +133,11 @@ class TensorReferenceImpl : public MSTensor::Impl { std::vector shape_; }; -MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector &shape, +MSTensor MSTensor::CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { + std::string name_str = CharToString(name); try { - std::shared_ptr impl = std::make_shared(name, type, shape, data, data_len); + std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len); return MSTensor(impl); } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; @@ -147,10 +148,11 @@ MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, con } } -MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, +MSTensor MSTensor::CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) noexcept { + std::string name_str = CharToString(name); try { - std::shared_ptr impl = std::make_shared(name, type, shape, data, data_len); + std::shared_ptr impl = std::make_shared(name_str, type, shape, data, data_len); return MSTensor(impl); } catch (const std::bad_alloc &) { MS_LOG(ERROR) << "Malloc memory failed."; @@ -164,9 +166,9 @@ MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, MSTensor::MSTensor() : impl_(std::make_shared()) {} MSTensor::MSTensor(std::nullptr_t) : impl_(nullptr) {} MSTensor::MSTensor(const std::shared_ptr &impl) : impl_(impl) { MS_EXCEPTION_IF_NULL(impl); } -MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, - size_t data_len) - : impl_(std::make_shared(name, type, shape, data, data_len)) {} +MSTensor::MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) + : impl_(std::make_shared(CharToString(name), type, shape, data, data_len)) {} MSTensor::~MSTensor() = default; bool MSTensor::operator==(std::nullptr_t) const { return impl_ == nullptr; } @@ -178,9 +180,9 @@ MSTensor MSTensor::Clone() const { return ret; } -const std::string &MSTensor::Name() const { +std::vector MSTensor::CharName() const { MS_EXCEPTION_IF_NULL(impl_); - return impl_->Name(); + return StringToChar(impl_->Name()); } enum DataType MSTensor::DataType() const { diff --git a/tests/ut/stub/include/api/context.h b/tests/ut/stub/include/api/context.h index 0aea49d..90dfa40 100644 --- a/tests/ut/stub/include/api/context.h +++ b/tests/ut/stub/include/api/context.h @@ -16,49 +16,120 @@ #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H #define MINDSPORE_INCLUDE_API_CONTEXT_H -#include -#include #include #include +#include #include "include/api/types.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { constexpr auto kDeviceTypeAscend310 = "Ascend310"; constexpr auto kDeviceTypeAscend910 = "Ascend910"; +constexpr auto kDeviceTypeGPU = "GPU"; struct MS_API Context { + public: + Context(); virtual ~Context() = default; - std::map params; + struct Data; + std::shared_ptr data; }; struct MS_API GlobalContext : public Context { + public: static std::shared_ptr GetGlobalContext(); - static void SetGlobalDeviceTarget(const std::string &device_target); - static std::string GetGlobalDeviceTarget(); + static inline void SetGlobalDeviceTarget(const std::string &device_target); + static inline std::string GetGlobalDeviceTarget(); static void SetGlobalDeviceID(const uint32_t &device_id); static uint32_t GetGlobalDeviceID(); + + private: + // api without std::string + static void SetGlobalDeviceTarget(const std::vector &device_target); + static std::vector GetGlobalDeviceTargetChar(); }; struct MS_API ModelContext : public Context { - static void SetInsertOpConfigPath(const std::shared_ptr &context, const std::string &cfg_path); - static std::string GetInsertOpConfigPath(const std::shared_ptr &context); + public: + static inline void SetInsertOpConfigPath(const std::shared_ptr &context, const std::string &cfg_path); + static inline std::string GetInsertOpConfigPath(const std::shared_ptr &context); - static void SetInputFormat(const std::shared_ptr &context, const std::string &format); - static std::string GetInputFormat(const std::shared_ptr &context); + static inline void SetInputFormat(const std::shared_ptr &context, const std::string &format); + static inline std::string GetInputFormat(const std::shared_ptr &context); - static void SetInputShape(const std::shared_ptr &context, const std::string &shape); - static std::string GetInputShape(const std::shared_ptr &context); + static inline void SetInputShape(const std::shared_ptr &context, const std::string &shape); + static inline std::string GetInputShape(const std::shared_ptr &context); static void SetOutputType(const std::shared_ptr &context, enum DataType output_type); static enum DataType GetOutputType(const std::shared_ptr &context); - static void SetPrecisionMode(const std::shared_ptr &context, const std::string &precision_mode); - static std::string GetPrecisionMode(const std::shared_ptr &context); + static inline void SetPrecisionMode(const std::shared_ptr &context, const std::string &precision_mode); + static inline std::string GetPrecisionMode(const std::shared_ptr &context); + + static inline void SetOpSelectImplMode(const std::shared_ptr &context, + const std::string &op_select_impl_mode); + static inline std::string GetOpSelectImplMode(const std::shared_ptr &context); + + private: + // api without std::string + static void SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path); + static std::vector GetInsertOpConfigPathChar(const std::shared_ptr &context); + + static void SetInputFormat(const std::shared_ptr &context, const std::vector &format); + static std::vector GetInputFormatChar(const std::shared_ptr &context); + + static void SetInputShape(const std::shared_ptr &context, const std::vector &shape); + static std::vector GetInputShapeChar(const std::shared_ptr &context); + + static void SetPrecisionMode(const std::shared_ptr &context, const std::vector &precision_mode); + static std::vector GetPrecisionModeChar(const std::shared_ptr &context); - static void SetOpSelectImplMode(const std::shared_ptr &context, const std::string &op_select_impl_mode); - static std::string GetOpSelectImplMode(const std::shared_ptr &context); + static void SetOpSelectImplMode(const std::shared_ptr &context, + const std::vector &op_select_impl_mode); + static std::vector GetOpSelectImplModeChar(const std::shared_ptr &context); }; + +void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { + SetGlobalDeviceTarget(StringToChar(device_target)); +} +std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); } + +void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::string &cfg_path) { + SetInsertOpConfigPath(context, StringToChar(cfg_path)); +} +std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr &context) { + return CharToString(GetInsertOpConfigPathChar(context)); +} + +void ModelContext::SetInputFormat(const std::shared_ptr &context, const std::string &format) { + SetInputFormat(context, StringToChar(format)); +} +std::string ModelContext::GetInputFormat(const std::shared_ptr &context) { + return CharToString(GetInputFormatChar(context)); +} + +void ModelContext::SetInputShape(const std::shared_ptr &context, const std::string &shape) { + SetInputShape(context, StringToChar(shape)); +} +std::string ModelContext::GetInputShape(const std::shared_ptr &context) { + return CharToString(GetInputShapeChar(context)); +} + +void ModelContext::SetPrecisionMode(const std::shared_ptr &context, const std::string &precision_mode) { + SetPrecisionMode(context, StringToChar(precision_mode)); +} +std::string ModelContext::GetPrecisionMode(const std::shared_ptr &context) { + return CharToString(GetPrecisionModeChar(context)); +} + +void ModelContext::SetOpSelectImplMode(const std::shared_ptr &context, + const std::string &op_select_impl_mode) { + SetOpSelectImplMode(context, StringToChar(op_select_impl_mode)); +} +std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr &context) { + return CharToString(GetOpSelectImplModeChar(context)); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/tests/ut/stub/include/api/data_type.h b/tests/ut/stub/include/api/data_type.h index c5e3b1d..a39488a 100644 --- a/tests/ut/stub/include/api/data_type.h +++ b/tests/ut/stub/include/api/data_type.h @@ -1,7 +1,5 @@ /** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_ #define MINDSPORE_INCLUDE_API_DATA_TYPE_H_ diff --git a/tests/ut/stub/include/api/dual_abi_helper.h b/tests/ut/stub/include/api/dual_abi_helper.h new file mode 100644 index 0000000..6bf9c6e --- /dev/null +++ b/tests/ut/stub/include/api/dual_abi_helper.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ +#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ + +#include +#include + +namespace mindspore { +inline std::vector StringToChar(const std::string &s) { return std::vector(s.begin(), s.end()); } +inline std::string CharToString(const std::vector &c) { return std::string(c.begin(), c.end()); } +} // namespace mindspore +#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_ diff --git a/tests/ut/stub/include/api/graph.h b/tests/ut/stub/include/api/graph.h index a9288eb..892f604 100644 --- a/tests/ut/stub/include/api/graph.h +++ b/tests/ut/stub/include/api/graph.h @@ -17,7 +17,6 @@ #define MINDSPORE_INCLUDE_API_GRAPH_H #include -#include #include #include #include diff --git a/tests/ut/stub/include/api/model.h b/tests/ut/stub/include/api/model.h index 8d40108..78f202f 100644 --- a/tests/ut/stub/include/api/model.h +++ b/tests/ut/stub/include/api/model.h @@ -25,6 +25,7 @@ #include "include/api/types.h" #include "include/api/graph.h" #include "include/api/cell.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { class ModelImpl; @@ -46,10 +47,16 @@ class MS_API Model { std::vector GetInputs(); std::vector GetOutputs(); - static bool CheckModelSupport(const std::string &device_type, ModelType model_type); + static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type); private: + // api without std::string + static bool CheckModelSupport(const std::vector &device_type, ModelType model_type); std::shared_ptr impl_; }; + +bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) { + return CheckModelSupport(StringToChar(device_type), model_type); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/tests/ut/stub/include/api/serialization.h b/tests/ut/stub/include/api/serialization.h index 2c34b82..c5fb61e 100644 --- a/tests/ut/stub/include/api/serialization.h +++ b/tests/ut/stub/include/api/serialization.h @@ -24,16 +24,24 @@ #include "include/api/types.h" #include "include/api/model.h" #include "include/api/graph.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { class MS_API Serialization { public: static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type); - static Graph LoadModel(const std::string &file, ModelType model_type); + inline static Graph LoadModel(const std::string &file, ModelType model_type); static Status LoadCheckPoint(const std::string &ckpt_file, std::map *parameters); static Status SetParameters(const std::map ¶meters, Model *model); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file); + + private: + static Graph LoadModel(const std::vector &file, ModelType model_type); }; + +Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { + return LoadModel(StringToChar(file), model_type); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H diff --git a/tests/ut/stub/include/api/status.h b/tests/ut/stub/include/api/status.h index 155959b..3a5c7c4 100644 --- a/tests/ut/stub/include/api/status.h +++ b/tests/ut/stub/include/api/status.h @@ -16,9 +16,13 @@ #ifndef MINDSPORE_INCLUDE_API_STATUS_H #define MINDSPORE_INCLUDE_API_STATUS_H +#include #include +#include #include #include +#include "include/api/dual_abi_helper.h" +#include "include/api/types.h" namespace mindspore { enum CompCode : uint32_t { @@ -100,39 +104,61 @@ enum StatusCode : uint32_t { kLiteInputParamInvalid = kLite | (0x0FFFFFFF & -600), /**< Invalid input param by user. */ }; -class Status { +class MS_API Status { public: - Status() : status_code_(kSuccess) {} - Status(enum StatusCode status_code, const std::string &status_msg = "") // NOLINT(runtime/explicit) - : status_code_(status_code), status_msg_(status_msg) {} - Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); + Status(); + inline Status(enum StatusCode status_code, const std::string &status_msg = ""); // NOLINT(runtime/explicit) + inline Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra = ""); ~Status() = default; - enum StatusCode StatusCode() const { return status_code_; } - const std::string &ToString() const { return status_msg_; } + enum StatusCode StatusCode() const; + inline std::string ToString() const; + + int GetLineOfCode() const; + inline std::string GetErrDescription() const; + inline std::string SetErrDescription(const std::string &err_description); friend std::ostream &operator<<(std::ostream &os, const Status &s); - bool operator==(const Status &other) const { return status_code_ == other.status_code_; } - bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; } - bool operator!=(const Status &other) const { return status_code_ != other.status_code_; } - bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; } + bool operator==(const Status &other) const; + bool operator==(enum StatusCode other_code) const; + bool operator!=(const Status &other) const; + bool operator!=(enum StatusCode other_code) const; - explicit operator bool() const { return (status_code_ == kSuccess); } - explicit operator int() const { return static_cast(status_code_); } + explicit operator bool() const; + explicit operator int() const; - static Status OK() { return Status(StatusCode::kSuccess); } + static Status OK(); - bool IsOk() const { return (StatusCode() == StatusCode::kSuccess); } + bool IsOk() const; - bool IsError() const { return !IsOk(); } + bool IsError() const; - static std::string CodeAsString(enum StatusCode c); + static inline std::string CodeAsString(enum StatusCode c); private: - enum StatusCode status_code_; - std::string status_msg_; + // api without std::string + explicit Status(enum StatusCode status_code, const std::vector &status_msg); + Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector &extra); + std::vector ToCString() const; + std::vector GetErrDescriptionChar() const; + std::vector SetErrDescription(const std::vector &err_description); + static std::vector CodeAsCString(enum StatusCode c); + + struct Data; + std::shared_ptr data_; }; + +Status::Status(enum StatusCode status_code, const std::string &status_msg) + : Status(status_code, StringToChar(status_msg)) {} +Status::Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::string &extra) + : Status(code, line_of_code, file_name, StringToChar(extra)) {} +std::string Status::ToString() const { return CharToString(ToCString()); } +std::string Status::GetErrDescription() const { return CharToString(GetErrDescriptionChar()); } +std::string Status::SetErrDescription(const std::string &err_description) { + return CharToString(SetErrDescription(StringToChar(err_description))); +} +std::string Status::CodeAsString(enum StatusCode c) { return CharToString(CodeAsCString(c)); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_STATUS_H diff --git a/tests/ut/stub/include/api/types.h b/tests/ut/stub/include/api/types.h index 8b9c79f..ea03ca9 100644 --- a/tests/ut/stub/include/api/types.h +++ b/tests/ut/stub/include/api/types.h @@ -21,22 +21,10 @@ #include #include #include "include/api/data_type.h" +#include "include/api/dual_abi_helper.h" -// refer to https://gcc.gnu.org/wiki/Visibility -#if defined _WIN32 || defined __CYGWIN__ -#ifdef BUILDING_DLL -#ifdef __GNUC__ -#define MS_API __attribute__((dllexport)) -#else -#define MS_API __declspec(dllexport) // Note: actually gcc seems to also supports this syntax. -#endif -#else -#ifdef __GNUC__ -#define MS_API __attribute__((dllimport)) -#else -#define MS_API __declspec(dllimport) // Note: actually gcc seems to also supports this syntax. -#endif -#endif +#ifdef _WIN32 +#define MS_API __declspec(dllexport) #else #define MS_API __attribute__((visibility("default"))) #endif @@ -55,18 +43,18 @@ class MS_API MSTensor { public: class Impl; - static MSTensor CreateTensor(const std::string &name, DataType type, const std::vector &shape, - const void *data, size_t data_len) noexcept; - static MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector &shape, - const void *data, size_t data_len) noexcept; + static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; MSTensor(); explicit MSTensor(const std::shared_ptr &impl); - MSTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, - size_t data_len); + inline MSTensor(const std::string &name, DataType type, const std::vector &shape, const void *data, + size_t data_len); ~MSTensor(); - const std::string &Name() const; + inline std::string Name() const; enum DataType DataType() const; const std::vector &Shape() const; int64_t ElementNum() const; @@ -81,6 +69,15 @@ class MS_API MSTensor { bool operator==(std::nullptr_t) const; private: + // api without std::string + static MSTensor CreateTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + static MSTensor CreateRefTensor(const std::vector &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept; + MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, + size_t data_len); + std::vector CharName() const; + friend class ModelImpl; explicit MSTensor(std::nullptr_t); std::shared_ptr impl_; @@ -123,5 +120,21 @@ class MS_API Buffer { class Impl; std::shared_ptr impl_; }; + +MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept { + return CreateTensor(StringToChar(name), type, shape, data, data_len); +} + +MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector &shape, + const void *data, size_t data_len) noexcept { + return CreateRefTensor(StringToChar(name), type, shape, data, data_len); +} + +MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector &shape, const void *data, + size_t data_len) + : MSTensor(StringToChar(name), type, shape, data, data_len) {} + +std::string MSTensor::Name() const { return CharToString(CharName()); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_TYPES_H diff --git a/tests/ut/stub/include/utils/utils.h b/tests/ut/stub/include/utils/utils.h index 2fe4feb..9012c41 100644 --- a/tests/ut/stub/include/utils/utils.h +++ b/tests/ut/stub/include/utils/utils.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "utils/log_adapter.h" namespace mindspore { diff --git a/third_party/mindspore b/third_party/mindspore index 52fac12..dd22b5e 160000 --- a/third_party/mindspore +++ b/third_party/mindspore @@ -1 +1 @@ -Subproject commit 52fac12367131ec57e87ba757e42fc25479f433a +Subproject commit dd22b5ea7106baf494704be04e2dbaad6887f0ab