Merge pull request !2561 from dinghao/master
tags/v0.6.0-beta
@@ -277,10 +277,11 @@ endif ()
 if (USE_GLOG)
     target_link_libraries(inference PRIVATE mindspore::glog)
-else()
-    if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-        target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
-    elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
-        set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
-    endif ()
 endif()
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+    target_link_options(inference PRIVATE -Wl,-init,common_log_init)
+elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
+    set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
+endif ()
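Note: -Wl,-init,<symbol> tells the GNU linker to record <symbol> as the shared object's init entry (DT_INIT), so the dynamic loader runs it when libinference.so is mapped; this hunk both renames the hook to common_log_init and moves it out of the else() branch, so it is registered even when glog is linked. A minimal standalone sketch of the mechanism (demo names, not from this PR):

    // demo.cc -- build: g++ -shared -fPIC demo.cc -o libdemo.so -Wl,-init,common_log_init
    #include <cstdio>

    extern "C" void common_log_init(void) {
      // called by the dynamic loader when libdemo.so is loaded (dlopen or program startup)
      std::printf("log init hook ran\n");
    }

On glibc/GNU ld, -init replaces the default _init entry and runs before the library's own __attribute__((constructor)) functions.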
@@ -33,9 +33,14 @@
 namespace py = pybind11;
 namespace mindspore::inference {
 std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
-  inference::Session::RegAllOp();
-  auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
-  return anf_graph;
+  try {
+    inference::Session::RegAllOp();
+    auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
+    return anf_graph;
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "Inference LoadModel failed";
+    return nullptr;
+  }
 }

 void ExitInference() {
@@ -51,12 +56,17 @@ void ExitInference() {
 }

 std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
-  auto session = std::make_shared<inference::Session>();
-  auto ret = session->Init(device, device_id);
-  if (ret != 0) {
+  try {
+    auto session = std::make_shared<inference::Session>();
+    auto ret = session->Init(device, device_id);
+    if (ret != 0) {
+      return nullptr;
+    }
+    return session;
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "Inference CreatSession failed";
     return nullptr;
   }
-  return session;
 }

 void Session::RegAllOp() {
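Review note: LoadModel and CreateSession sit on the boundary of the inference shared library, where an escaping C++ exception would terminate the calling process; both now trap std::exception and report failure through the return value instead. A hypothetical helper illustrating the same pattern (NoThrowBoundary is illustrative, not part of the PR):

    // Sketch: funnel any std::exception into a null result at the API boundary.
    template <typename Fn>
    auto NoThrowBoundary(Fn &&fn) -> decltype(fn()) {
      try {
        return fn();
      } catch (const std::exception &e) {
        MS_LOG(ERROR) << "inference call failed: " << e.what();
        return nullptr;  // assumes fn() returns a pointer-like type
      }
    }

One possible refinement: the catch blocks here drop e.what(); logging it would preserve the root cause.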
@@ -113,47 +123,71 @@ void Session::RegAllOp() {
 uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
   MS_ASSERT(session_impl_ != nullptr);
-  auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
-  py::gil_scoped_release gil_release;
-  return graph_id;
+  try {
+    auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
+    py::gil_scoped_release gil_release;
+    return graph_id;
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "Inference CompileGraph failed";
+    return static_cast<uint32_t>(-1);
+  }
 }
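Since graph ids are unsigned, failure is signaled with static_cast<uint32_t>(-1), i.e. 0xFFFFFFFF. Caller-side check (sketch, names assumed):

    uint32_t graph_id = session->CompileGraph(func_graph);
    if (graph_id == static_cast<uint32_t>(-1)) {  // sentinel: compilation failed
      return FAILED;
    }

Warmup in the serving hunk below does not appear to check this sentinel before handing the id to RunGraph; callers that care should.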
 MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
-  std::vector<tensor::TensorPtr> inTensors;
-  inTensors.resize(inputs.size());
-  bool has_error = false;
-  std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
-                 [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
-                   if (tensor_ptr == nullptr) {
-                     MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
-                     has_error = true;
-                     return nullptr;
-                   }
-                   auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
-                   if (tensor == nullptr) {
-                     MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
-                     has_error = true;
-                     return nullptr;
-                   }
-                   return tensor->tensor();
-                 });
-  if (has_error) {
-    MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
-    std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
-    return multiTensor;
-  }
-  VectorRef outputs;
-  session_impl_->RunGraph(graph_id, inTensors, &outputs);
-  return TransformVectorRefToMultiTensor(outputs);
+  try {
+    std::vector<tensor::TensorPtr> inTensors;
+    inTensors.resize(inputs.size());
+    bool has_error = false;
+    std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
+                   [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
+                     if (tensor_ptr == nullptr) {
+                       MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
+                       has_error = true;
+                       return nullptr;
+                     }
+                     auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
+                     if (tensor == nullptr) {
+                       MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
+                       has_error = true;
+                       return nullptr;
+                     }
+                     return tensor->tensor();
+                   });
+    if (has_error) {
+      MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
+      std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
+      return multiTensor;
+    }
+    VectorRef outputs;
+    session_impl_->RunGraph(graph_id, inTensors, &outputs);
+    return TransformVectorRefToMultiTensor(outputs);
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "Inference Rungraph failed";
+    return MultiTensor();
+  }
 }
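Both failure paths return an empty collection (MultiTensor here is a vector of MSTensor pointers, per the multiTensor local above), so a caller detects failure by emptiness:

    auto outputs = session->RunGraph(graph_id, inputs);
    if (outputs.empty()) {  // empty result doubles as the error signal
      MS_LOG(ERROR) << "RunGraph failed or produced no outputs";
    }

This conflates "failed" with "graph legitimately produced zero outputs"; a status-plus-result return would disambiguate, at the cost of an API change.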
+namespace {
+string AjustTargetName(const std::string &device) {
+  if (device == kAscendDevice) {
+    return std::string(kAscendDevice) + "Inference";
+  } else {
+    MS_LOG(ERROR) << "Only support device Ascend right now";
+    return "";
+  }
+}
+}  // namespace
+
 int Session::Init(const std::string &device, uint32_t device_id) {
   RegAllOp();
   auto ms_context = MsContext::GetInstance();
   ms_context->set_execution_mode(kGraphMode);
-  ms_context->set_device_target(kAscendDevice);
-  session_impl_ = session::SessionFactory::Get().Create(device);
+  ms_context->set_device_id(device_id);
+  auto ajust_device = AjustTargetName(device);
+  if (ajust_device == "") {
+    return -1;
+  }
+  ms_context->set_device_target(device);
+  session_impl_ = session::SessionFactory::Get().Create(ajust_device);
   if (session_impl_ == nullptr) {
     MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
     return -1;
@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
   // set submodule's log level
   auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
-  MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
+  MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
   LogConfigParser parser(submodule);
   auto configs = parser.Parse();
   for (const auto &cfg : configs) {
@@ -489,22 +489,14 @@
 }  // namespace mindspore

 extern "C" {
-// shared lib init hook
 #if defined(_WIN32) || defined(_WIN64)
-__attribute__((constructor)) void mindspore_log_init(void) {
+__attribute__((constructor)) void common_log_init(void) {
 #else
-void mindspore_log_init(void) {
+void common_log_init(void) {
 #endif
 #ifdef USE_GLOG
   // do not use glog predefined log prefix
   FLAGS_log_prefix = false;
-  static bool is_glog_initialzed = false;
-  if (!is_glog_initialzed) {
-#if !defined(_WIN32) && !defined(_WIN64)
-    google::InitGoogleLogging("mindspore");
-#endif
-    is_glog_initialzed = true;
-  }
   // set default log level to WARNING
   if (mindspore::GetEnv("GLOG_v").empty()) {
     FLAGS_v = mindspore::WARNING;
@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
 #endif
   mindspore::InitSubModulesLogLevel();
 }
+
+// shared lib init hook
+#if defined(_WIN32) || defined(_WIN64)
+__attribute__((constructor)) void mindspore_log_init(void) {
+#else
+void mindspore_log_init(void) {
+#endif
+#ifdef USE_GLOG
+  static bool is_glog_initialzed = false;
+  if (!is_glog_initialzed) {
+#if !defined(_WIN32) && !defined(_WIN64)
+    google::InitGoogleLogging("mindspore");
+#endif
+    is_glog_initialzed = true;
+  }
+#endif
+  common_log_init();
+}
 }
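The refactor splits the old hook in two: common_log_init now only configures log prefixes and levels (and is what the inference library's -Wl,-init flag invokes, per the CMake hunk above), while mindspore_log_init remains the entry that additionally performs one-time glog initialization before delegating. The once-guard is a plain static bool, which is fine for single-threaded load-time init; a thread-safe equivalent would be std::call_once (sketch, not the PR's code):

    #include <mutex>

    static std::once_flag glog_once;

    void mindspore_log_init_threadsafe(void) {
      // std::call_once guarantees exactly-once execution even under races
      std::call_once(glog_once, [] { google::InitGoogleLogging("mindspore"); });
      common_log_init();
    }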
@@ -22,6 +22,7 @@
 #include <vector>
 #include <utility>
 #include <memory>
+#include <future>
 #include "mindspore/ccsrc/utils/log_adapter.h"
 #include "serving/ms_service.grpc.pb.h"
@@ -40,7 +41,7 @@ namespace serving {
 using MSTensorPtr = std::shared_ptr<inference::MSTensor>;

 Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
-  session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
+  session_ = inference::MSSession::CreateSession(device, device_id);
   if (session_ == nullptr) {
     MS_LOG(ERROR) << "Creat Session Failed";
     return FAILED;
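This pairs with the AjustTargetName hunk above: serving now passes the bare device name, and the "Inference" suffix is appended in exactly one place, inside Session::Init. The resulting lookup chain:

    // Factory-key derivation after this PR (illustrative trace, not code from the diff):
    //   serving:   CreateSession("Ascend", device_id)
    //   inference: AjustTargetName("Ascend") -> "AscendInference"
    //   factory:   session::SessionFactory::Get().Create("AscendInference")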
@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
   MS_LOG(INFO) << "run Predict";
   *outputs = session_->RunGraph(graph_id_, inputs);
+  MS_LOG(INFO) << "run Predict finished";
   return SUCCESS;
 }
@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
   std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
   char *graphBuf = ReadFile(file_name.c_str(), &size);
   if (graphBuf == nullptr) {
-    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
+    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
     return FAILED;
   }
   last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
+  if (last_graph_ == nullptr) {
+    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
+    return FAILED;
+  }
   graph_id_ = session_->CompileGraph(last_graph_);
-  MS_LOG(INFO) << "Session Warmup";
+  MS_LOG(INFO) << "Session Warmup finished";
   return SUCCESS;
 }
@@ -95,6 +101,9 @@ Status Session::Clear() {
 }

 namespace {
+static const uint32_t uint32max = 0x7FFFFFFF;
+std::promise<void> exit_requested;
+
 const std::map<ms_serving::DataType, TypeId> type2id_map{
   {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
   {ms_serving::MS_INT8, TypeId::kNumberTypeInt8},     {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
   }
   TypeId type = iter->second;
   auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
-  memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size());
+  memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
   return ms_tensor;
 }
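Correct fix: memcpy_s's second parameter bounds the destination, and the old call passed the source size there, defeating the overflow check whenever the protobuf payload exceeds the tensor buffer. For reference (C11 Annex K / securec signature):

    // errno_t memcpy_s(void *dest, rsize_t destMax, const void *src, rsize_t count);
    // destMax must be the capacity of dest -- here ms_tensor->Size(), not tensor.data().size().

The return value is still ignored; checking it (nonzero means the copy was refused) would surface oversized payloads instead of silently proceeding.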
@@ -166,10 +175,7 @@ void ClearEnv() {
   Session::Instance().Clear();
   inference::ExitInference();
 }
-void HandleSignal(int sig) {
-  ClearEnv();
-  exit(0);
-}
+void HandleSignal(int sig) { exit_requested.set_value(); }

 #ifdef ENABLE_D
 static rtContext_t g_ctx = nullptr;
@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
   rtError_t rt_ret = rtCtxGetCurrent(&ctx);
   if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
     MS_LOG(ERROR) << "the ascend device context is null";
+    ClearEnv();
     return FAILED;
   }
   g_ctx = ctx;
@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
   auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
   grpc::ServerBuilder builder;
   builder.SetOption(std::move(option));
+  builder.SetMaxMessageSize(uint32max);
   // Listen on the given address without any authentication mechanism.
   builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
   // Register "service" as the instance through which we'll communicate with
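Naming nit: uint32max is defined as 0x7FFFFFFF, which is INT32_MAX rather than UINT32_MAX; that matches grpc::ServerBuilder::SetMaxMessageSize(int) and effectively lifts gRPC's 4 MB default receive limit so large tensors fit in one message. The more explicit builder calls would be (sketch):

    builder.SetMaxReceiveMessageSize(INT32_MAX);  // default receive limit is 4 MB
    builder.SetMaxSendMessageSize(INT32_MAX);     // send side is unlimited by default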
@@ -265,13 +273,15 @@ Status Server::BuildAndStart() {
   builder.RegisterService(&service);
   // Finally assemble the server.
   std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+  auto grpc_server_run = [&server]() { server->Wait(); };
+  std::thread serving_thread(grpc_server_run);
   MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
-  // Wait for the server to shutdown. Note that some other thread must be
-  // responsible for shutting down the server for this call to ever return.
-  server->Wait();
+  auto exit_future = exit_requested.get_future();
+  exit_future.wait();
+  ClearEnv();
+  server->Shutdown();
+  serving_thread.join();
   return SUCCESS;
 }
 }  // namespace serving
 }  // namespace mindspore
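The shutdown rework: Wait() moves to a worker thread, the main thread parks on the exit_requested future until HandleSignal fires, then tears down state, shuts the server down, and joins. A self-contained sketch of the pattern (hypothetical names):

    #include <csignal>
    #include <future>
    #include <thread>

    static std::promise<void> exit_requested;

    void HandleSignal(int) { exit_requested.set_value(); }

    int main() {
      std::signal(SIGINT, HandleSignal);
      std::thread worker([] { /* server->Wait(); */ });
      exit_requested.get_future().wait();  // block until a signal arrives
      /* ClearEnv(); server->Shutdown(); */
      worker.join();
    }

Two caveats worth noting in review: std::promise::set_value is not on POSIX's async-signal-safe list, so strictly conforming code would set a volatile sig_atomic_t flag or write to a self-pipe instead; and a second signal would call set_value() twice, which throws std::future_error.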
@@ -29,7 +29,6 @@
 namespace mindspore {
 namespace serving {
-
 char *ReadFile(const char *file, size_t *size) {
   if (file == nullptr) {
     MS_LOG(ERROR) << "file is nullptr";
@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
 }

 std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
-  DIR *dir;
-  struct dirent *ptr;
+  DIR *dir = nullptr;
+  struct dirent *ptr = nullptr;
   std::vector<std::string> SubDirs;
   if ((dir = opendir(dir_path.c_str())) == NULL) {
@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
 bool Option::ParseInt32(std::string *arg) {
   if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
-    char extra;
     int32_t parsed_value;
-    if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) {
-      std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
+    try {
+      parsed_value = std::stoi(arg->data());
+    } catch (std::invalid_argument) {
+      std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
       return false;
-    } else {
-      *int32_default_ = parsed_value;
     }
+    *int32_default_ = parsed_value;
     return true;
   }
   return false;
 }
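Behavioral notes on the sscanf-to-std::stoi switch (the same applies to ParseFloat/std::stof below): std::stoi also throws std::out_of_range on overflow, which this catch does not cover, and catching std::invalid_argument by value is legal but unconventional. std::stoi additionally accepts trailing garbage ("5500x" parses as 5500), whereas the old "%d%c" sscanf pattern rejected it. A slightly stricter variant (sketch, not the PR's code):

    std::size_t pos = 0;
    try {
      parsed_value = std::stoi(*arg, &pos);
    } catch (const std::invalid_argument &) {
      return false;  // not a number at all
    } catch (const std::out_of_range &) {
      return false;  // does not fit in int
    }
    if (pos != arg->size()) {
      return false;  // trailing characters, e.g. "5500x"
    }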
@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
 bool Option::ParseFloat(std::string *arg) {
   if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
-    char extra;
     float parsed_value;
-    if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) {
-      std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
+    try {
+      parsed_value = std::stof(arg->data());
+    } catch (std::invalid_argument) {
+      std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
       return false;
-    } else {
-      *float_default_ = parsed_value;
     }
+    *float_default_ = parsed_value;
     return true;
   }
   return false;
 }
@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
 void Options::CreateOptions() {
   args_ = std::make_shared<Arguments>();
   std::vector<Option> options = {
-    Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"),
-    Option("model_name", &args_->model_name, "model name "),
-    Option("model_path", &args_->model_path, "the path of the model files"),
-    Option("device_id", &args_->device_id, "the device id, default is 0"),
+    Option("port", &args_->grpc_port,
+           "[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
+    Option("model_name", &args_->model_name, "[Required] model name "),
+    Option("model_path", &args_->model_path, "[Required] the path of the model files"),
+    Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
   };
   options_ = options;
 }
@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
     std::cout << "device_type only support Ascend right now" << std::endl;
     return false;
   }
+  if (args_->device_id > 7) {
+    std::cout << "the device_id should be in [0~7]" << std::endl;
+    return false;
+  }
+  if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
+    std::cout << "the port should be in [1~65535]" << std::endl;
+    return false;
+  }
   return true;
 }
@@ -238,6 +245,5 @@ void Options::Usage() {
               << option.usage_ << std::endl;
   }
 }
-
 }  // namespace serving
 }  // namespace mindspore
@@ -22,7 +22,6 @@
 namespace mindspore {
 namespace serving {
-
 struct Arguments {
   int32_t grpc_port = 5500;
   std::string grpc_socket_path;
@@ -40,6 +39,7 @@ class Option {
   Option(const std::string &name, bool *default_point, const std::string &usage);
   Option(const std::string &name, std::string *default_point, const std::string &usage);
   Option(const std::string &name, float *default_point, const std::string &usage);
+  ~Option() = default;

  private:
   friend class Options;
@@ -77,7 +77,6 @@ class Options {
   std::vector<Option> options_;
   std::shared_ptr<Arguments> args_;
 };
-
 }  // namespace serving
 }  // namespace mindspore
@@ -19,7 +19,6 @@
 namespace mindspore {
 namespace serving {
-
 MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
                                const std::string &model_version, const time_t &last_update_time)
     : model_name_(model_name),
@@ -25,7 +25,6 @@
 namespace mindspore {
 namespace serving {
-
 volatile bool stop_poll = false;

 std::string GetVersionFromPath(const std::string &path) {
@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
   }
   std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
   if (version_control_strategy_ == kLastest) {
-    auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
-    std::string model_version = GetVersionFromPath(path);
-    time_t last_update_time = GetModifyTime(path);
-    MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time);
+    std::string model_version = GetVersionFromPath(models_path_);
+    time_t last_update_time = GetModifyTime(models_path_);
+    MindSporeModelPtr model_ptr =
+      std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
     valid_models_.emplace_back(model_ptr);
   } else {
     for (auto &dir : SubDirs) {
@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
     MS_LOG(ERROR) << "There is no valid model for serving";
     return FAILED;
   }
-  Session::Instance().Warmup(valid_models_.back());
-  return SUCCESS;
+  auto ret = Session::Instance().Warmup(valid_models_.back());
+  return ret;
 }
| void VersionController::StartPollModelPeriodic() { | void VersionController::StartPollModelPeriodic() { | ||||
| @@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() { | |||||
| } | } | ||||
| void VersionController::StopPollModelPeriodic() {} | void VersionController::StopPollModelPeriodic() {} | ||||
| } // namespace serving | } // namespace serving | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
@@ -64,7 +64,6 @@ class PeriodicFunction {
   VersionController::VersionControllerStrategy version_control_strategy_;
   std::vector<MindSporeModelPtr> valid_models_;
 };
-
 }  // namespace serving
 }  // namespace mindspore
@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
 class MSClient {
  public:
   explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
+  ~MSClient() = default;

   std::string Predict(const std::string &type) {
     // Data we are sending to the server.
@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
       type = "add";
     }
-  }
   } else {
     target_str = "localhost:5500";
     type = "add";