@@ -41,6 +41,7 @@ cmake-build-debug
*.pb.h
*.pb.cc
*.pb
+*_grpc.py
# Object files
*.o
@@ -24,20 +24,20 @@
namespace mindspore {
namespace inference {
enum Status { SUCCESS = 0, FAILED, INVALID_INPUTS };
class MS_API InferSession {
public:
InferSession() = default;
virtual ~InferSession() = default;
-virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
-virtual bool FinalizeEnv() = 0;
-virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
-virtual bool UnloadModel(uint32_t model_id) = 0;
+virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
+virtual Status FinalizeEnv() = 0;
+virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
+virtual Status UnloadModel(uint32_t model_id) = 0;
// override this method to avoid request/reply data copy
-virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
+virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
-virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
-std::vector<InferTensor> &outputs) {
+virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
+std::vector<InferTensor> &outputs) {
VectorInferTensorWrapRequest request(inputs);
VectorInferTensorWrapReply reply(outputs);
return ExecuteModel(model_id, request, reply);
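// Illustrative only (not part of the patch): a minimal caller sketch for the new Status-based
// API above. The header path, device string, and model file name are assumptions; the vector
// overload shown here wraps inputs/outputs and forwards to the RequestBase/ReplyBase overload.
#include <cstdint>
#include <string>
#include <vector>
#include "include/inference.h"  // assumed location of InferSession/InferTensor/Status

int RunOnce(const std::vector<mindspore::inference::InferTensor> &inputs,
            std::vector<mindspore::inference::InferTensor> *outputs) {
  using mindspore::inference::InferSession;
  using mindspore::inference::Status;
  using mindspore::inference::SUCCESS;
  auto session = InferSession::CreateSession("Ascend", 0);  // returns nullptr when InitEnv fails
  if (session == nullptr) {
    return -1;
  }
  uint32_t model_id = 0;
  if (session->LoadModelFromFile("add_model.pb", model_id) != SUCCESS) {  // hypothetical file
    return -1;
  }
  Status ret = session->ExecuteModel(model_id, inputs, *outputs);
  session->UnloadModel(model_id);
  session->FinalizeEnv();
  return ret == SUCCESS ? 0 : -1;
}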
@@ -37,8 +37,8 @@ namespace mindspore::inference {
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
try {
auto session = std::make_shared<MSInferSession>();
-bool ret = session->InitEnv(device, device_id);
-if (!ret) {
+Status ret = session->InitEnv(device, device_id);
+if (ret != SUCCESS) {
return nullptr;
}
return session;
@@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
return buf;
}
-bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
+Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
auto graphBuf = ReadFile(file_name);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
-return false;
+return FAILED;
}
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
if (graph == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-return false;
+return FAILED;
}
-bool ret = CompileGraph(graph, model_id);
-if (!ret) {
+Status ret = CompileGraph(graph, model_id);
+if (ret != SUCCESS) {
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
-return false;
+return FAILED;
}
MS_LOG(INFO) << "Load model from file " << file_name << " success";
@@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
-return false;
+return FAILED;
}
#endif
-return true;
+return SUCCESS;
}
-bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
+Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
std::vector<int> shape;
@@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
}
-bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
+Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
#ifdef ENABLE_D
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
-return false;
+return FAILED;
}
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "set Ascend rtCtx failed";
-return false;
+return FAILED;
}
#endif
@@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
for (size_t i = 0; i < request.size(); i++) {
if (request[i] == nullptr) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
-return false;
+return FAILED;
}
auto input = ServingTensor2MSTensor(*request[i]);
if (input == nullptr) {
MS_LOG(ERROR) << "Tensor convert failed";
-return false;
+return FAILED;
}
inputs.push_back(input);
}
if (!CheckModelInputs(model_id, inputs)) {
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
-return false;
+return INVALID_INPUTS;
}
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
-return false;
+return FAILED;
}
reply.clear();
for (const auto &tensor : outputs) {
auto out_tensor = reply.add();
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
-return false;
+return FAILED;
}
MSTensor2ServingTensor(tensor, *out_tensor);
}
-return true;
+return SUCCESS;
}
-bool MSInferSession::FinalizeEnv() {
+Status MSInferSession::FinalizeEnv() {
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
-return false;
+return FAILED;
}
if (!ms_context->CloseTsd()) {
MS_LOG(ERROR) << "Inference CloseTsd failed!";
-return false;
+return FAILED;
}
-return true;
+return SUCCESS;
}
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
@@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
return;
}
-bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
+Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
MS_ASSERT(session_impl_ != nullptr);
try {
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
model_id = graph_id;
-return true;
+return SUCCESS;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed";
-return false;
+return FAILED;
}
}
@@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
}
}
-bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
+Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
RegAllOp();
auto ms_context = MsContext::GetInstance();
ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_id(device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
-return false;
+return FAILED;
}
ms_context->set_device_target(device);
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
-return false;
+return FAILED;
}
session_impl_->Init(device_id);
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
-return false;
+return FAILED;
}
if (!ms_context->OpenTsd()) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
-return false;
+return FAILED;
}
-return true;
+return SUCCESS;
}
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
@@ -38,11 +38,11 @@ class MSInferSession : public InferSession {
MSInferSession();
~MSInferSession();
-bool InitEnv(const std::string &device_type, uint32_t device_id) override;
-bool FinalizeEnv() override;
-bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
-bool UnloadModel(uint32_t model_id) override;
-bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
+Status InitEnv(const std::string &device_type, uint32_t device_id) override;
+Status FinalizeEnv() override;
+Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
+Status UnloadModel(uint32_t model_id) override;
+Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
private:
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
@@ -57,7 +57,7 @@ class MSInferSession : public InferSession {
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
static void RegAllOp();
string AjustTargetName(const std::string &device);
-bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
+Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
};
@@ -35,53 +35,53 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
}
}
-bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
-return model_process_.LoadModelFromFile(file_name, model_id);
+Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
+return model_process_.LoadModelFromFile(file_name, model_id) ? SUCCESS : FAILED;
}
-bool AclSession::UnloadModel(uint32_t model_id) {
+Status AclSession::UnloadModel(uint32_t model_id) {
model_process_.UnLoad();
-return true;
+return SUCCESS;
}
-bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
-ReplyBase &reply) { // set d context
+Status AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
+ReplyBase &reply) { // set d context
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "set the ascend device context failed";
-return false;
+return FAILED;
}
-return model_process_.Execute(request, reply);
+return model_process_.Execute(request, reply) ? SUCCESS : FAILED;
}
-bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
+Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
device_type_ = device_type;
device_id_ = device_id;
auto ret = aclInit(nullptr);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "Execute aclInit Failed";
-return false;
+return FAILED;
}
MSI_LOG_INFO << "acl init success";
ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
-return false;
+return FAILED;
}
MSI_LOG_INFO << "open device " << device_id_ << " success";
ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl create context failed";
-return false;
+return FAILED;
}
MSI_LOG_INFO << "create context success";
ret = aclrtCreateStream(&stream_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl create stream failed";
-return false;
+return FAILED;
}
MSI_LOG_INFO << "create stream success";
@@ -89,17 +89,17 @@ bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl get run mode failed";
-return false;
+return FAILED;
}
bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device);
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
-return true;
+return SUCCESS;
}
-bool AclSession::FinalizeEnv() {
+Status AclSession::FinalizeEnv() {
aclError ret;
if (stream_ != nullptr) {
ret = aclrtDestroyStream(stream_);
@@ -129,7 +129,7 @@ bool AclSession::FinalizeEnv() {
MSI_LOG_ERROR << "finalize acl failed";
}
MSI_LOG_INFO << "end to finalize acl";
-return true;
+return SUCCESS;
}
AclSession::AclSession() = default;
@@ -32,11 +32,11 @@ class AclSession : public InferSession {
public:
AclSession();
-bool InitEnv(const std::string &device_type, uint32_t device_id) override;
-bool FinalizeEnv() override;
-bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
-bool UnloadModel(uint32_t model_id) override;
-bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
+Status InitEnv(const std::string &device_type, uint32_t device_id) override;
+Status FinalizeEnv() override;
+Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
+Status UnloadModel(uint32_t model_id) override;
+Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
private:
std::string device_type_;
@@ -31,6 +31,7 @@
#include "core/version_control/version_controller.h"
#include "core/util/file_system_operation.h"
#include "core/serving_tensor.h"
+#include "util/status.h"
using ms_serving::MSService;
using ms_serving::PredictReply;
@@ -79,9 +80,9 @@ Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
MSI_LOG(INFO) << "run Predict finished";
-if (!ret) {
+if (Status(ret) != SUCCESS) {
MSI_LOG(ERROR) << "execute model return failed";
-return FAILED;
+return Status(ret);
}
return SUCCESS;
}
@@ -97,9 +98,9 @@ Status Session::Warmup(const MindSporeModelPtr model) {
MSI_TIME_STAMP_START(LoadModelFromFile)
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
MSI_TIME_STAMP_END(LoadModelFromFile)
-if (!ret) {
+if (Status(ret) != SUCCESS) {
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-return FAILED;
+return Status(ret);
}
model_loaded_ = true;
MSI_LOG(INFO) << "Session Warmup finished";
@@ -119,12 +120,22 @@ namespace {
static const uint32_t uint32max = 0x7FFFFFFF;
std::promise<void> exit_requested;
-void ClearEnv() {
-Session::Instance().Clear();
-// inference::ExitInference();
-}
+void ClearEnv() { Session::Instance().Clear(); }
void HandleSignal(int sig) { exit_requested.set_value(); }
+grpc::Status CreatGRPCStatus(Status status) {
+switch (status) {
+case SUCCESS:
+return grpc::Status::OK;
+case FAILED:
+return grpc::Status::CANCELLED;
+case INVALID_INPUTS:
+return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "The Predict Inputs do not match the Model Request!");
+default:
+return grpc::Status::CANCELLED;
+}
+}
} // namespace
// Service Implement
@@ -134,8 +145,8 @@ class MSServiceImpl final : public MSService::Service {
MSI_TIME_STAMP_START(Predict)
auto res = Session::Instance().Predict(*request, *reply);
MSI_TIME_STAMP_END(Predict)
-if (res != SUCCESS) {
-return grpc::Status::CANCELLED;
+if (res != inference::SUCCESS) {
+return CreatGRPCStatus(res);
}
MSI_LOG(INFO) << "Finish call service Eval";
return grpc::Status::OK;
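// Illustrative only (not part of the patch): with the CreatGRPCStatus mapping above, a gRPC
// client can tell rejected inputs apart from other failures. This sketch only inspects the
// returned grpc::Status, so no serving headers are assumed.
#include <string>
#include <grpcpp/grpcpp.h>

std::string DescribePredictStatus(const grpc::Status &status) {
  if (status.ok()) {
    return "ok";
  }
  if (status.error_code() == grpc::StatusCode::INVALID_ARGUMENT) {
    // Server-side CheckModelInputs failed, i.e. serving returned INVALID_INPUTS.
    return "invalid inputs: " + status.error_message();
  }
  // FAILED (and any other serving status) is mapped to CANCELLED.
  return "predict failed: " + status.error_message();
}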
@@ -18,7 +18,7 @@
namespace mindspore {
namespace serving {
using Status = uint32_t;
-enum ServingStatus { SUCCESS = 0, FAILED };
+enum ServingStatus { SUCCESS = 0, FAILED, INVALID_INPUTS };
} // namespace serving
} // namespace mindspore
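// Illustrative sketch (not part of the patch): Session::Predict forwards the engine result
// with Status(ret), so serving::ServingStatus is assumed to stay value-aligned with
// inference::Status. A compile-time check like the following could pin that assumption down;
// the include paths are assumptions.
#include <cstdint>
#include "include/inference.h"  // assumed path of mindspore::inference::Status
#include "util/status.h"        // mindspore::serving::ServingStatus

static_assert(static_cast<uint32_t>(mindspore::serving::SUCCESS) ==
              static_cast<uint32_t>(mindspore::inference::SUCCESS), "status enums must stay aligned");
static_assert(static_cast<uint32_t>(mindspore::serving::FAILED) ==
              static_cast<uint32_t>(mindspore::inference::FAILED), "status enums must stay aligned");
static_assert(static_cast<uint32_t>(mindspore::serving::INVALID_INPUTS) ==
              static_cast<uint32_t>(mindspore::inference::INVALID_INPUTS), "status enums must stay aligned");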
@@ -31,51 +31,51 @@ using ms_serving::TensorShape;
class MSClient {
 public:
  explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
  ~MSClient() = default;
  std::string Predict() {
    // Data we are sending to the server.
    PredictRequest request;
    Tensor data;
    TensorShape shape;
    shape.add_dims(4);
    *data.mutable_tensor_shape() = shape;
    data.set_tensor_type(ms_serving::MS_FLOAT32);
    std::vector<float> input_data{1, 2, 3, 4};
    data.set_data(input_data.data(), input_data.size() * sizeof(float));
    *request.add_data() = data;
    *request.add_data() = data;
    std::cout << "input tensor size is " << request.data_size() << std::endl;
    // Container for the data we expect from the server.
    PredictReply reply;
    // Context for the client. It could be used to convey extra information to
    // the server and/or tweak certain RPC behaviors.
    ClientContext context;
    // The actual RPC.
    Status status = stub_->Predict(&context, request, &reply);
    std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
    // Act upon its status.
    if (status.ok()) {
      std::cout << "Add result is";
      for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
        std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
      }
      std::cout << std::endl;
      return "RPC OK";
    } else {
      std::cout << status.error_code() << ": " << status.error_message() << std::endl;
      return "RPC failed";
    }
  }
 private:
  std::unique_ptr<MSService::Stub> stub_;
};
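// Illustrative only (not part of the patch): typical construction of the client above over a
// plain insecure channel; the real argument handling lives in main() below.
//   MSClient client(grpc::CreateChannel("localhost:5500", grpc::InsecureChannelCredentials()));
//   std::cout << client.Predict() << std::endl;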
int main(int argc, char **argv) {
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
+import sys
import grpc
import numpy as np
import ms_service_pb2
@@ -19,7 +20,19 @@ import ms_service_pb2_grpc
def run():
-    channel = grpc.insecure_channel('localhost:5500')
+    if len(sys.argv) > 2:
+        sys.exit("input error")
+    channel_str = ""
+    if len(sys.argv) == 2:
+        split_args = sys.argv[1].split('=')
+        if len(split_args) > 1:
+            channel_str = split_args[1]
+        else:
+            channel_str = 'localhost:5500'
+    else:
+        channel_str = 'localhost:5500'
+    channel = grpc.insecure_channel(channel_str)
    stub = ms_service_pb2_grpc.MSServiceStub(channel)
    request = ms_service_pb2.PredictRequest()
@@ -33,11 +46,17 @@ def run():
    y.tensor_type = ms_service_pb2.MS_FLOAT32
    y.data = (np.ones([4]).astype(np.float32)).tobytes()
-    result = stub.Predict(request)
-    print(result)
-    result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
-    print("ms client received: ")
-    print(result_np)
+    try:
+        result = stub.Predict(request)
+        print(result)
+        result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
+        print("ms client received: ")
+        print(result_np)
+    except grpc.RpcError as e:
+        print(e.details())
+        status_code = e.code()
+        print(status_code.name)
+        print(status_code.value)
if __name__ == '__main__':
    run()