| @@ -96,4 +96,8 @@ if (ENABLE_TESTCASES) | |||||
| add_subdirectory(tests) | add_subdirectory(tests) | ||||
| endif() | endif() | ||||
| if (ENABLE_SERVING) | |||||
| add_subdirectory(serving) | |||||
| endif() | |||||
| include(cmake/package.cmake) | include(cmake/package.cmake) | ||||
| @@ -53,6 +53,7 @@ usage() | |||||
| echo " -V Specify the minimum required cuda version, default CUDA 9.2" | echo " -V Specify the minimum required cuda version, default CUDA 9.2" | ||||
| echo " -I Compile predict, default off" | echo " -I Compile predict, default off" | ||||
| echo " -K Compile with AKG, default off" | echo " -K Compile with AKG, default off" | ||||
| echo " -s Enable serving module, default off" | |||||
| } | } | ||||
| # check value of input is 'on' or 'off' | # check value of input is 'on' or 'off' | ||||
| @@ -92,9 +93,9 @@ checkopts() | |||||
| USE_GLOG="on" | USE_GLOG="on" | ||||
| PREDICT_PLATFORM="" | PREDICT_PLATFORM="" | ||||
| ENABLE_AKG="off" | ENABLE_AKG="off" | ||||
| ENABLE_SERVING="off" | |||||
| # Process the options | # Process the options | ||||
| while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt | |||||
| while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:s' opt | |||||
| do | do | ||||
| OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | ||||
| case "${opt}" in | case "${opt}" in | ||||
| @@ -235,6 +236,10 @@ checkopts() | |||||
| ENABLE_AKG="on" | ENABLE_AKG="on" | ||||
| echo "enable compile with akg" | echo "enable compile with akg" | ||||
| ;; | ;; | ||||
| s) | |||||
| ENABLE_SERVING="on" | |||||
| echo "enable serving" | |||||
| ;; | |||||
| *) | *) | ||||
| echo "Unknown option ${opt}!" | echo "Unknown option ${opt}!" | ||||
| usage | usage | ||||
| @@ -314,6 +319,10 @@ build_mindspore() | |||||
| if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then | if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then | ||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" | ||||
| fi | fi | ||||
| if [[ "X$ENABLE_SERVING" = "Xon" ]]; then | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON" | |||||
| fi | |||||
| echo "${CMAKE_ARGS}" | echo "${CMAKE_ARGS}" | ||||
| if [[ "X$INC_BUILD" = "Xoff" ]]; then | if [[ "X$INC_BUILD" = "Xoff" ]]; then | ||||
| cmake ${CMAKE_ARGS} ../.. | cmake ${CMAKE_ARGS} ../.. | ||||
| @@ -37,6 +37,8 @@ class MS_API MSSession { | |||||
| }; | }; | ||||
| std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device); | std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device); | ||||
| void MS_API ExitInference(); | |||||
| } // namespace inference | } // namespace inference | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_INCLUDE_MS_SESSION_H | #endif // MINDSPORE_INCLUDE_MS_SESSION_H | ||||
| @@ -247,7 +247,7 @@ add_library(inference SHARED | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc | ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc | ||||
| ${LOAD_ONNX_SRC} | ${LOAD_ONNX_SRC} | ||||
| ) | ) | ||||
| target_link_libraries(inference PRIVATE ${PYTHON_LIB} ${SECUREC_LIBRARY} | |||||
| target_link_libraries(inference PRIVATE ${PYTHON_LIBRARY} ${SECUREC_LIBRARY} | |||||
| -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) | -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) | ||||
| if (ENABLE_CPU) | if (ENABLE_CPU) | ||||
| @@ -38,6 +38,18 @@ std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const s | |||||
| return anf_graph; | return anf_graph; | ||||
| } | } | ||||
| void ExitInference() { | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| if (ms_context == nullptr) { | |||||
| MS_LOG(ERROR) << "Get Context failed!"; | |||||
| return; | |||||
| } | |||||
| if (!ms_context->CloseTsd()) { | |||||
| MS_LOG(ERROR) << "Inference CloseTsd failed!"; | |||||
| return; | |||||
| } | |||||
| } | |||||
| std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) { | std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) { | ||||
| auto session = std::make_shared<inference::Session>(); | auto session = std::make_shared<inference::Session>(); | ||||
| auto ret = session->Init(device, device_id); | auto ret = session->Init(device, device_id); | ||||
| @@ -101,11 +113,14 @@ void Session::RegAllOp() { | |||||
// Compiles the given function graph through the backend session implementation
// and returns the backend graph id.
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
  MS_ASSERT(session_impl_ != nullptr);
  auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
  // NOTE(review): this gil_scoped_release is constructed after the compile has
  // already finished and is destroyed at the very next statement (the return),
  // so the GIL is released only for an instant — confirm whether it was meant
  // to wrap the CompileGraph call above instead.
  py::gil_scoped_release gil_release;
  return graph_id;
}
| MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) { | MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) { | ||||
| std::vector<tensor::TensorPtr> inTensors; | std::vector<tensor::TensorPtr> inTensors; | ||||
| inTensors.resize(inputs.size()); | |||||
| bool has_error = false; | bool has_error = false; | ||||
| std::transform(inputs.begin(), inputs.end(), inTensors.begin(), | std::transform(inputs.begin(), inputs.end(), inTensors.begin(), | ||||
| [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr { | [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr { | ||||
| @@ -144,6 +159,14 @@ int Session::Init(const std::string &device, uint32_t device_id) { | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| session_impl_->Init(device_id); | session_impl_->Init(device_id); | ||||
| if (ms_context == nullptr) { | |||||
| MS_LOG(ERROR) << "Get Context failed!"; | |||||
| return -1; | |||||
| } | |||||
| if (!ms_context->OpenTsd()) { | |||||
| MS_LOG(ERROR) << "Session init OpenTsd failed!"; | |||||
| return -1; | |||||
| } | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -0,0 +1,69 @@ | |||||
find_package(Threads REQUIRED)

# gRPC and protobuf (and all their dependencies) are assumed to be already
# installed on this system, so they can be located by find_package().
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
set(_REFLECTION gRPC::grpc++_reflection)
if (CMAKE_CROSSCOMPILING)
    find_program(_PROTOBUF_PROTOC protoc)
else ()
    set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif ()

# Find gRPC installation (gRPCConfig.cmake installed by gRPC's cmake install).
find_package(gRPC CONFIG REQUIRED)
message(STATUS "Using gRPC ${gRPC_VERSION}")
set(_GRPC_GRPCPP gRPC::grpc++)
if (CMAKE_CROSSCOMPILING)
    find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else ()
    set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
endif ()

# Proto file and the C++ sources generated from it.
get_filename_component(hw_proto "ms_service.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
add_custom_command(
        OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
        COMMAND ${_PROTOBUF_PROTOC}
        ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
        --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
        -I "${hw_proto_path}"
        --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
        "${hw_proto}"
        DEPENDS "${hw_proto}")

# Include generated *.pb.h files
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/core"
        "${PROJECT_SOURCE_DIR}/mindspore/ccsrc")

# GLOB_RECURSE with "core/*.cc" already matches core/util and
# core/version_control; listing those patterns again produced duplicate
# entries in CORE_SRC_LIST.
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "core/*.cc")
list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST})
include_directories(${CMAKE_BINARY_DIR})
add_executable(ms_serving ${SERVING_SRC})
target_link_libraries(ms_serving inference mindspore_gvar)
# Use the imported Threads target instead of a bare "pthread" so the
# find_package(Threads) result (and cross builds) are honored.
target_link_libraries(ms_serving ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF} Threads::Threads)
if (ENABLE_D)
    add_compile_definitions(ENABLE_D)
    target_link_libraries(ms_serving ${RUNTIME_LIB})
endif ()
| @@ -0,0 +1,36 @@ | |||||
| # serving | |||||
| #### Description | |||||
| A flexible, high-performance serving system for deep learning models | |||||
| #### Software Architecture | |||||
| Software architecture description | |||||
| #### Installation | |||||
| 1. xxxx | |||||
| 2. xxxx | |||||
| 3. xxxx | |||||
| #### Instructions | |||||
| 1. xxxx | |||||
| 2. xxxx | |||||
| 3. xxxx | |||||
| #### Contribution | |||||
| 1. Fork the repository | |||||
| 2. Create Feat_xxx branch | |||||
| 3. Commit your code | |||||
| 4. Create Pull Request | |||||
| #### Gitee Feature | |||||
| 1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md | |||||
| 2. Gitee blog [blog.gitee.com](https://blog.gitee.com) | |||||
| 3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore) | |||||
| 4. The most valuable open source project [GVP](https://gitee.com/gvp) | |||||
| 5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help) | |||||
| 6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) | |||||
| @@ -0,0 +1,37 @@ | |||||
| # serving | |||||
| #### 介绍 | |||||
| A flexible, high-performance serving system for deep learning models | |||||
| #### 软件架构 | |||||
| 软件架构说明 | |||||
| #### 安装教程 | |||||
| 1. xxxx | |||||
| 2. xxxx | |||||
| 3. xxxx | |||||
| #### 使用说明 | |||||
| 1. xxxx | |||||
| 2. xxxx | |||||
| 3. xxxx | |||||
| #### 参与贡献 | |||||
| 1. Fork 本仓库 | |||||
| 2. 新建 Feat_xxx 分支 | |||||
| 3. 提交代码 | |||||
| 4. 新建 Pull Request | |||||
| #### 码云特技 | |||||
| 1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md | |||||
| 2. 码云官方博客 [blog.gitee.com](https://blog.gitee.com) | |||||
| 3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解码云上的优秀开源项目 | |||||
| 4. [GVP](https://gitee.com/gvp) 全称是码云最有价值开源项目,是码云综合评定出的优秀开源项目 | |||||
| 5. 码云官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) | |||||
| 6. 码云封面人物是一档用来展示码云会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) | |||||
| @@ -0,0 +1,277 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "core/server.h" | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <grpcpp/health_check_service_interface.h> | |||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include "mindspore/ccsrc/utils/log_adapter.h" | |||||
| #include "serving/ms_service.grpc.pb.h" | |||||
| #include "core/util/option_parser.h" | |||||
| #include "core/version_control/version_controller.h" | |||||
| #include "mindspore/ccsrc/utils/context/ms_context.h" | |||||
| #include "core/util/file_system_operation.h" | |||||
| #include "graphengine/third_party/fwkacllib/inc/runtime/context.h" | |||||
| using ms_serving::MSService; | |||||
| using ms_serving::PredictReply; | |||||
| using ms_serving::PredictRequest; | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
| using MSTensorPtr = std::shared_ptr<inference::MSTensor>; | |||||
| Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { | |||||
| session_ = inference::MSSession::CreateSession(device + "Inference", device_id); | |||||
| if (session_ == nullptr) { | |||||
| MS_LOG(ERROR) << "Creat Session Failed"; | |||||
| return FAILED; | |||||
| } | |||||
| device_type_ = device; | |||||
| return SUCCESS; | |||||
| } | |||||
| Session &Session::Instance() { | |||||
| static Session instance; | |||||
| return instance; | |||||
| } | |||||
| Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) { | |||||
| if (last_graph_ == nullptr) { | |||||
| MS_LOG(ERROR) << "the model has not loaded"; | |||||
| return FAILED; | |||||
| } | |||||
| if (session_ == nullptr) { | |||||
| MS_LOG(ERROR) << "the inference session has not be initialized"; | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| MS_LOG(INFO) << "run Predict"; | |||||
| *outputs = session_->RunGraph(graph_id_, inputs); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Session::Warmup(const MindSporeModelPtr model) { | |||||
| if (session_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup"; | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| size_t size = 0; | |||||
| std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); | |||||
| char *graphBuf = ReadFile(file_name.c_str(), &size); | |||||
| if (graphBuf == nullptr) { | |||||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||||
| return FAILED; | |||||
| } | |||||
| last_graph_ = inference::LoadModel(graphBuf, size, device_type_); | |||||
| graph_id_ = session_->CompileGraph(last_graph_); | |||||
| MS_LOG(INFO) << "Session Warmup"; | |||||
| return SUCCESS; | |||||
| } | |||||
// Drops the device session (releases backend resources).
// NOTE(review): runs without holding mutex_, and is reached from the SIGINT
// handler via ClearEnv() — confirm callers cannot race with Predict()/Warmup().
Status Session::Clear() {
  session_ = nullptr;
  return SUCCESS;
}
| namespace { | |||||
// Mapping from the serving proto dtype enum to MindSpore TypeId.
// MS_UNKNOWN maps to kNumberTypeBegin, which acts as the "invalid" sentinel.
const std::map<ms_serving::DataType, TypeId> type2id_map{
  {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
  {ms_serving::MS_INT8, TypeId::kNumberTypeInt8},     {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
  {ms_serving::MS_INT16, TypeId::kNumberTypeInt16},   {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
  {ms_serving::MS_INT32, TypeId::kNumberTypeInt32},   {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
  {ms_serving::MS_INT64, TypeId::kNumberTypeInt64},   {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
  {ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
  {ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
};

// Inverse mapping (TypeId -> proto dtype); must stay in sync with type2id_map.
const std::map<TypeId, ms_serving::DataType> id2type_map{
  {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
  {TypeId::kNumberTypeInt8, ms_serving::MS_INT8},     {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
  {TypeId::kNumberTypeInt16, ms_serving::MS_INT16},   {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
  {TypeId::kNumberTypeInt32, ms_serving::MS_INT32},   {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
  {TypeId::kNumberTypeInt64, ms_serving::MS_INT64},   {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
  {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
  {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
};

// Byte width of one element per proto dtype (FLOAT16 has no C++ scalar type,
// hence the literal 2; MS_UNKNOWN has width 0).
const std::map<ms_serving::DataType, size_t> length_map{
  {ms_serving::MS_UNKNOWN, 0},
  {ms_serving::MS_BOOL, sizeof(bool)},
  {ms_serving::MS_INT8, sizeof(int8_t)},
  {ms_serving::MS_UINT8, sizeof(uint8_t)},
  {ms_serving::MS_INT16, sizeof(int16_t)},
  {ms_serving::MS_UINT16, sizeof(uint16_t)},
  {ms_serving::MS_INT32, sizeof(int32_t)},
  {ms_serving::MS_UINT32, sizeof(uint32_t)},
  {ms_serving::MS_INT64, sizeof(int64_t)},
  {ms_serving::MS_UINT64, sizeof(uint64_t)},
  {ms_serving::MS_FLOAT16, 2},
  {ms_serving::MS_FLOAT32, 4},
  {ms_serving::MS_FLOAT64, 8},
};
| MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { | |||||
| std::vector<int> shape; | |||||
| for (auto dim : tensor.tensor_shape().dims()) { | |||||
| shape.push_back(static_cast<int>(dim)); | |||||
| } | |||||
| auto iter = type2id_map.find(tensor.tensor_type()); | |||||
| if (iter == type2id_map.end()) { | |||||
| MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type(); | |||||
| return nullptr; | |||||
| } | |||||
| TypeId type = iter->second; | |||||
| auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape)); | |||||
| memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size()); | |||||
| return ms_tensor; | |||||
| } | |||||
| ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) { | |||||
| ms_serving::Tensor tensor; | |||||
| ms_serving::TensorShape shape; | |||||
| for (auto dim : ms_tensor->shape()) { | |||||
| shape.add_dims(dim); | |||||
| } | |||||
| *tensor.mutable_tensor_shape() = shape; | |||||
| auto iter = id2type_map.find(ms_tensor->data_type()); | |||||
| if (iter == id2type_map.end()) { | |||||
| MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type(); | |||||
| return tensor; | |||||
| } | |||||
| tensor.set_tensor_type(iter->second); | |||||
| tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size()); | |||||
| return tensor; | |||||
| } | |||||
// Releases serving-side state and shuts down the inference runtime.
void ClearEnv() {
  Session::Instance().Clear();
  inference::ExitInference();
}

// SIGINT handler: best-effort cleanup before exiting.
// NOTE(review): ClearEnv()/exit() are not async-signal-safe; confirm this is
// acceptable for shutdown or defer the cleanup to the main thread.
void HandleSignal(int sig) {
  ClearEnv();
  exit(0);
}
| #ifdef ENABLE_D | |||||
| static rtContext_t g_ctx = nullptr; | |||||
| #endif | |||||
| } // namespace | |||||
| // Service Implement | |||||
// gRPC service implementation. A single mutex serializes every Predict RPC,
// so requests are handled one at a time.
class MSServiceImpl final : public MSService::Service {
  // Converts request tensors, runs one inference, converts results back.
  // Any conversion or inference failure cancels the RPC.
  grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    std::lock_guard<std::mutex> lock(mutex_);
#ifdef ENABLE_D
    // Ascend: restore the runtime context captured at server startup before
    // touching the device from this (gRPC worker) thread.
    if (g_ctx == nullptr) {
      MS_LOG(ERROR) << "rtCtx is nullptr";
      return grpc::Status::CANCELLED;
    }
    rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
    if (rt_ret != RT_ERROR_NONE) {
      // NOTE(review): failure here only logs and the RPC continues — confirm
      // whether it should cancel instead.
      MS_LOG(ERROR) << "set Ascend rtCtx failed";
    }
#endif
    std::vector<MSTensorPtr> inputs;
    inference::MultiTensor outputs;
    // Abort on the first tensor that fails conversion.
    for (int i = 0; i < request->data_size(); i++) {
      auto input = ServingTensor2MSTensor(request->data(i));
      if (input == nullptr) {
        MS_LOG(ERROR) << "Tensor convert failed";
        return grpc::Status::CANCELLED;
      }
      inputs.push_back(input);
    }
    auto res = Session::Instance().Predict(inputs, &outputs);
    if (res != SUCCESS) {
      return grpc::Status::CANCELLED;
    }
    for (const auto &tensor : outputs) {
      *reply->add_result() = MSTensor2ServingTensor(tensor);
    }
    MS_LOG(INFO) << "Finish call service Eval";
    return grpc::Status::OK;
  }

  // Lightweight liveness/echo endpoint; performs no inference.
  grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    MS_LOG(INFO) << "TestService call";
    return grpc::Status::OK;
  }
  std::mutex mutex_;  // serializes Predict RPCs
};
| Status Server::BuildAndStart() { | |||||
| // handle exit signal | |||||
| signal(SIGINT, HandleSignal); | |||||
| Status res; | |||||
| auto option_args = Options::Instance().GetArgs(); | |||||
| std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port); | |||||
| std::string model_path = option_args->model_path; | |||||
| std::string model_name = option_args->model_name; | |||||
| std::string device_type = option_args->device_type; | |||||
| auto device_id = option_args->device_id; | |||||
| res = Session::Instance().CreatDeviceSession(device_type, device_id); | |||||
| if (res != SUCCESS) { | |||||
| MS_LOG(ERROR) << "creat session failed"; | |||||
| ClearEnv(); | |||||
| return res; | |||||
| } | |||||
| VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name); | |||||
| res = version_controller.Run(); | |||||
| if (res != SUCCESS) { | |||||
| MS_LOG(ERROR) << "load model failed"; | |||||
| ClearEnv(); | |||||
| return res; | |||||
| } | |||||
| #ifdef ENABLE_D | |||||
| // set d context | |||||
| rtContext_t ctx = nullptr; | |||||
| rtError_t rt_ret = rtCtxGetCurrent(&ctx); | |||||
| if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "the ascend device context is null"; | |||||
| return FAILED; | |||||
| } | |||||
| g_ctx = ctx; | |||||
| #endif | |||||
| MSServiceImpl service; | |||||
| grpc::EnableDefaultHealthCheckService(true); | |||||
| grpc::reflection::InitProtoReflectionServerBuilderPlugin(); | |||||
| // Set the port is not reuseable | |||||
| auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); | |||||
| grpc::ServerBuilder builder; | |||||
| builder.SetOption(std::move(option)); | |||||
| // Listen on the given address without any authentication mechanism. | |||||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); | |||||
| // Register "service" as the instance through which we'll communicate with | |||||
| // clients. In this case it corresponds to an *synchronous* service. | |||||
| builder.RegisterService(&service); | |||||
| // Finally assemble the server. | |||||
| std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); | |||||
| MS_LOG(INFO) << "Server listening on " << server_address << std::endl; | |||||
| // Wait for the server to shutdown. Note that some other thread must be | |||||
| // responsible for shutting down the server for this call to ever return. | |||||
| server->Wait(); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVER_H | |||||
| #define MINDSPORE_SERVER_H | |||||
| #include <string> | |||||
| #include <mutex> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "util/status.h" | |||||
| #include "version_control/model.h" | |||||
| #include "include/inference.h" | |||||
| #include "mindspore/ccsrc/debug/info.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
// Singleton owning the serving-side inference state: the device session,
// the most recently loaded graph, and the mutex serializing access to them.
class Session {
 public:
  static Session &Instance();
  // Creates the backend session for `device` (e.g. "Ascend") on `device_id`.
  Status CreatDeviceSession(const std::string &device, uint32_t device_id);
  // Runs one inference on the warmed-up graph; writes results into *output.
  Status Predict(const std::vector<std::shared_ptr<inference::MSTensor>> &inputs, inference::MultiTensor *output);
  // Loads and compiles the model file described by `model`.
  Status Warmup(const MindSporeModelPtr model);
  // Releases the device session.
  Status Clear();

 private:
  Session() = default;
  ~Session() = default;
  // NOTE(review): misspelled ("sesseion") and not referenced by the visible
  // implementation — consider renaming or removing.
  int sesseion_id_{0};
  std::shared_ptr<inference::MSSession> session_{nullptr};  // backend session
  FuncGraphPtr last_graph_{nullptr};                        // last loaded model graph
  uint32_t graph_id_{0};                                    // id returned by CompileGraph
  std::mutex mutex_;                                        // guards warmup/predict state
  std::string device_type_;                                 // device target, e.g. "Ascend"
};
// gRPC serving entry point: BuildAndStart() sets up the session and model,
// then blocks serving requests until shutdown.
class Server {
 public:
  Server() = default;
  ~Server() = default;
  Status BuildAndStart();
};
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_SERVER_H | |||||
| @@ -0,0 +1,102 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "core/util/file_system_operation.h" | |||||
| #include <unistd.h> | |||||
| #include <dirent.h> | |||||
| #include <sys/types.h> | |||||
| #include <sys/stat.h> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <iostream> | |||||
| #include <algorithm> | |||||
| #include <ctime> | |||||
| #include <fstream> | |||||
| #include <memory> | |||||
| #include "mindspore/ccsrc/utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
| char *ReadFile(const char *file, size_t *size) { | |||||
| if (file == nullptr) { | |||||
| MS_LOG(ERROR) << "file is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(size != nullptr); | |||||
| std::string realPath = file; | |||||
| std::ifstream ifs(realPath); | |||||
| if (!ifs.good()) { | |||||
| MS_LOG(ERROR) << "file: " << realPath << " is not exist"; | |||||
| return nullptr; | |||||
| } | |||||
| if (!ifs.is_open()) { | |||||
| MS_LOG(ERROR) << "file: " << realPath << "open failed"; | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::end); | |||||
| *size = ifs.tellg(); | |||||
| std::unique_ptr<char> buf(new (std::nothrow) char[*size]); | |||||
| if (buf == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc buf failed, file: " << realPath; | |||||
| ifs.close(); | |||||
| return nullptr; | |||||
| } | |||||
| ifs.seekg(0, std::ios::beg); | |||||
| ifs.read(buf.get(), *size); | |||||
| ifs.close(); | |||||
| return buf.release(); | |||||
| } | |||||
// True when the path exists (file or directory); access() with mode 0 (F_OK)
// tests for existence only.
bool DirOrFileExist(const std::string &file_path) {
  return access(file_path.c_str(), 0) != -1;
}
| std::vector<std::string> GetAllSubDirs(const std::string &dir_path) { | |||||
| DIR *dir; | |||||
| struct dirent *ptr; | |||||
| std::vector<std::string> SubDirs; | |||||
| if ((dir = opendir(dir_path.c_str())) == NULL) { | |||||
| MS_LOG(ERROR) << "Open " << dir_path << " error!"; | |||||
| return std::vector<std::string>(); | |||||
| } | |||||
| while ((ptr = readdir(dir)) != NULL) { | |||||
| std::string name = ptr->d_name; | |||||
| if (name == "." || name == "..") { | |||||
| continue; | |||||
| } | |||||
| if (ptr->d_type == DT_DIR) { | |||||
| SubDirs.push_back(dir_path + "/" + name); | |||||
| } | |||||
| } | |||||
| closedir(dir); | |||||
| std::sort(SubDirs.begin(), SubDirs.end()); | |||||
| return SubDirs; | |||||
| } | |||||
// Last-modification time of file_path. Returns 0 when stat() fails — the old
// code ignored the stat result and returned an uninitialized st_mtime.
time_t GetModifyTime(const std::string &file_path) {
  struct stat info {};
  if (stat(file_path.c_str(), &info) != 0) {
    return 0;
  }
  return info.st_mtime;
}
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_ | |||||
| #define MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <ctime> | |||||
namespace mindspore {
namespace serving {
// Reads the whole file into a heap buffer; caller owns the returned pointer
// and *size receives the byte count. Returns nullptr on failure.
char *ReadFile(const char *file, size_t *size);
// True when the path exists (file or directory).
bool DirOrFileExist(const std::string &file_path);
// Sorted immediate subdirectories of dir_path ("." and ".." excluded).
std::vector<std::string> GetAllSubDirs(const std::string &dir_path);
// Last-modification time of file_path.
time_t GetModifyTime(const std::string &file_path);
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
| @@ -0,0 +1,243 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "core/util/option_parser.h" | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <cstring> | |||||
| #include <iostream> | |||||
| #include <iomanip> | |||||
| #include "mindspore/ccsrc/utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
// True when `str` begins with `expected`; an empty prefix always matches.
bool StartWith(const std::string &str, const std::string &expected) {
  if (expected.empty()) {
    return true;
  }
  return str.size() >= expected.size() && str.compare(0, expected.size(), expected) == 0;
}
| bool RemovePrefix(std::string *str, const std::string &prefix) { | |||||
| if (!StartWith(*str, prefix)) return false; | |||||
| str->replace(str->begin(), str->begin() + prefix.size(), ""); | |||||
| return true; | |||||
| } | |||||
| bool Option::ParseInt32(std::string *arg) { | |||||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||||
| char extra; | |||||
| int32_t parsed_value; | |||||
| if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) { | |||||
| std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; | |||||
| return false; | |||||
| } else { | |||||
| *int32_default_ = parsed_value; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool Option::ParseBool(std::string *arg) { | |||||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||||
| if (*arg == "true") { | |||||
| *bool_default_ = true; | |||||
| } else if (*arg == "false") { | |||||
| *bool_default_ = false; | |||||
| } else { | |||||
| std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool Option::ParseString(std::string *arg) { | |||||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||||
| *string_default_ = *arg; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool Option::ParseFloat(std::string *arg) { | |||||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||||
| char extra; | |||||
| float parsed_value; | |||||
| if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) { | |||||
| std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; | |||||
| return false; | |||||
| } else { | |||||
| *float_default_ = parsed_value; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
// Binds an int32 option: a successful parse writes to *default_point.
Option::Option(const std::string &name, int32_t *default_point, const std::string &usage)
    : name_(name),
      type_(MS_TYPE_INT32),
      int32_default_(default_point),
      bool_default_(nullptr),
      string_default_(nullptr),
      float_default_(nullptr),
      usage_(usage) {}
// Binds a bool option ("--name=true|false").
Option::Option(const std::string &name, bool *default_point, const std::string &usage)
    : name_(name),
      type_(MS_TYPE_BOOL),
      int32_default_(nullptr),
      bool_default_(default_point),
      string_default_(nullptr),
      float_default_(nullptr),
      usage_(usage) {}
// Binds a string option; the value is stored verbatim.
Option::Option(const std::string &name, std::string *default_point, const std::string &usage)
    : name_(name),
      type_(MS_TYPE_STRING),
      int32_default_(nullptr),
      bool_default_(nullptr),
      string_default_(default_point),
      float_default_(nullptr),
      usage_(usage) {}
// Binds a float option.
Option::Option(const std::string &name, float *default_point, const std::string &usage)
    : name_(name),
      type_(MS_TYPE_FLOAT),
      int32_default_(nullptr),
      bool_default_(nullptr),
      string_default_(nullptr),
      float_default_(default_point),
      usage_(usage) {}
| bool Option::Parse(std::string *arg) { | |||||
| bool result = false; | |||||
| switch (type_) { | |||||
| case MS_TYPE_BOOL: | |||||
| result = ParseBool(arg); | |||||
| break; | |||||
| case MS_TYPE_FLOAT: | |||||
| result = ParseFloat(arg); | |||||
| break; | |||||
| case MS_TYPE_INT32: | |||||
| result = ParseInt32(arg); | |||||
| break; | |||||
| case MS_TYPE_STRING: | |||||
| result = ParseString(arg); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| return result; | |||||
| } | |||||
// NOTE(review): inst_ appears unused in this file -- Instance() below uses a
// function-local static instead. Confirm inst_ can be removed.
std::shared_ptr<Options> Options::inst_ = nullptr;
// Returns the process-wide Options instance (thread-safe initialization is
// guaranteed for function-local statics since C++11).
Options &Options::Instance() {
  static Options instance;
  return instance;
}
// Builds the argument storage and the option table on construction.
Options::Options() : args_(nullptr) { CreateOptions(); }
| void Options::CreateOptions() { | |||||
| args_ = std::make_shared<Arguments>(); | |||||
| std::vector<Option> options = { | |||||
| Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"), | |||||
| Option("model_name", &args_->model_name, "model name "), | |||||
| Option("model_path", &args_->model_path, "the path of the model files"), | |||||
| Option("device_id", &args_->device_id, "the device id, default is 0"), | |||||
| }; | |||||
| options_ = options; | |||||
| } | |||||
| bool Options::CheckOptions() { | |||||
| if (args_->model_name == "" || args_->model_path == "") { | |||||
| std::cout << "model_path and model_name should not be null" << std::endl; | |||||
| return false; | |||||
| } | |||||
| if (args_->device_type != "Ascend") { | |||||
| std::cout << "device_type only support Ascend right now" << std::endl; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool Options::ParseCommandLine(int argc, char **argv) { | |||||
| if (argc < 2 || (strcmp(argv[1], "--help") == 0)) { | |||||
| Usage(); | |||||
| return false; | |||||
| } | |||||
| std::vector<std::string> unkown_options; | |||||
| for (int i = 1; i < argc; ++i) { | |||||
| bool found = false; | |||||
| for (auto &option : options_) { | |||||
| std::string arg = argv[i]; | |||||
| if (option.Parse(&arg)) { | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (found == false) { | |||||
| unkown_options.push_back(argv[i]); | |||||
| } | |||||
| } | |||||
| if (!unkown_options.empty()) { | |||||
| std::cout << "unkown options:" << std::endl; | |||||
| for (const auto &option : unkown_options) { | |||||
| std::cout << option << std::endl; | |||||
| } | |||||
| } | |||||
| bool valid = (unkown_options.empty() && CheckOptions()); | |||||
| if (!valid) { | |||||
| Usage(); | |||||
| } | |||||
| return valid; | |||||
| } | |||||
| void Options::Usage() { | |||||
| std::cout << "USAGE: mindspore-serving [options]" << std::endl; | |||||
| for (const auto &option : options_) { | |||||
| std::string type; | |||||
| switch (option.type_) { | |||||
| case Option::MS_TYPE_BOOL: | |||||
| type = "bool"; | |||||
| break; | |||||
| case Option::MS_TYPE_FLOAT: | |||||
| type = "float"; | |||||
| break; | |||||
| case Option::MS_TYPE_INT32: | |||||
| type = "int32"; | |||||
| break; | |||||
| case Option::MS_TYPE_STRING: | |||||
| type = "string"; | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| std::cout << "--" << std::setw(30) << std::left << option.name_ << std::setw(10) << std::left << type | |||||
| << option.usage_ << std::endl; | |||||
| } | |||||
| } | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,84 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVING_OPTION_PARSER_H_ | |||||
| #define MINDSPORE_SERVING_OPTION_PARSER_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
// Aggregated command-line argument values with their defaults.
// Only grpc_port, model_name, model_path and device_id are registered as
// options in Options::CreateOptions(); the remaining fields currently keep
// their defaults.
struct Arguments {
  int32_t grpc_port = 5500;               // gRPC listen port
  std::string grpc_socket_path;           // not exposed as an option yet
  std::string ssl_config_file;            // not exposed as an option yet
  int32_t poll_model_wait_seconds = 1;    // model-poll interval (seconds)
  std::string model_name;                 // required; checked in CheckOptions()
  std::string model_path;                 // required; checked in CheckOptions()
  std::string device_type = "Ascend";     // only "Ascend" passes CheckOptions()
  int32_t device_id = 0;                  // target device index
};
// One command-line option of the form "--name=value", bound at construction
// time to a typed destination pointer that a successful parse writes into.
class Option {
 public:
  // Each overload fixes the option's type and the destination it writes.
  Option(const std::string &name, int32_t *default_point, const std::string &usage);
  Option(const std::string &name, bool *default_point, const std::string &usage);
  Option(const std::string &name, std::string *default_point, const std::string &usage);
  Option(const std::string &name, float *default_point, const std::string &usage);

 private:
  friend class Options;
  // Type-specific parsers: each strips "--<name_>=" from *arg and converts
  // the remainder; they return false when the name or value does not match.
  bool ParseInt32(std::string *arg);
  bool ParseBool(std::string *arg);
  bool ParseString(std::string *arg);
  bool ParseFloat(std::string *arg);
  // Dispatches on type_ to one of the parsers above.
  bool Parse(std::string *arg);
  std::string name_;  // option name without the leading "--"
  enum { MS_TYPE_INT32, MS_TYPE_BOOL, MS_TYPE_STRING, MS_TYPE_FLOAT } type_;
  int32_t *int32_default_;       // destination when type_ == MS_TYPE_INT32
  bool *bool_default_;           // destination when type_ == MS_TYPE_BOOL
  std::string *string_default_;  // destination when type_ == MS_TYPE_STRING
  float *float_default_;         // destination when type_ == MS_TYPE_FLOAT
  std::string usage_;            // help text shown by Options::Usage()
};
// Singleton registry of command-line options; parses argv into an Arguments
// instance shared via GetArgs().
class Options {
 public:
  ~Options() = default;
  Options(const Options &) = delete;
  Options &operator=(const Options &) = delete;
  // Returns the process-wide instance (function-local static).
  static Options &Instance();
  // Parses argv; prints usage and returns false on "--help", unknown options,
  // or invalid values.
  bool ParseCommandLine(int argc, char **argv);
  // Prints the option table to stdout.
  void Usage();
  // Parsed argument values; meaningful after a successful ParseCommandLine().
  std::shared_ptr<Arguments> GetArgs() { return args_; }

 private:
  Options();
  void CreateOptions();
  bool CheckOptions();
  // NOTE(review): inst_ and usage_ appear unused in this file (Instance()
  // returns a function-local static) -- confirm they can be removed.
  static std::shared_ptr<Options> inst_;
  std::string usage_;
  std::vector<Option> options_;
  std::shared_ptr<Arguments> args_;
};
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| #endif | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
#ifndef MINDSPORE_STATUS_H
#define MINDSPORE_STATUS_H
// <cstdint> added: uint32_t was used without any include providing it.
#include <cstdint>
namespace mindspore {
namespace serving {
// Serving status code; functions return SUCCESS (0) or FAILED (1).
using Status = uint32_t;
enum ServingStatus { SUCCESS = 0, FAILED };
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_STATUS_H
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "core/version_control/model.h" | |||||
| #include <string> | |||||
| #include "mindspore/ccsrc/utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
// Stores the model's identifying fields and logs them for traceability.
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
                               const std::string &model_version, const time_t &last_update_time)
    : model_name_(model_name),
      model_path_(model_path),
      model_version_(model_version),
      last_update_time_(last_update_time) {
  MS_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_
               << ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_;
}
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
#ifndef MINDSPORE_SERVING_MODEL_H_
#define MINDSPORE_SERVING_MODEL_H_
#include <string>
#include <ctime>
#include <memory>
namespace mindspore {
namespace serving {
// Value object describing one loadable model: its name, on-disk path, version
// string and the modification time observed when it was last scanned.
class MindSporeModel {
 public:
  MindSporeModel(const std::string &model_name, const std::string &model_path, const std::string &model_version,
                 const time_t &last_update_time);
  ~MindSporeModel() = default;
  // Accessors are const-qualified so the object can be inspected through
  // const references/pointers (the original getters were non-const).
  std::string GetModelName() const { return model_name_; }
  std::string GetModelPath() const { return model_path_; }
  std::string GetModelVersion() const { return model_version_; }
  time_t GetLastUpdateTime() const { return last_update_time_; }
  // Records a newer modification time after a rescan.
  void SetLastUpdateTime(const time_t &last_update_time) { last_update_time_ = last_update_time; }

 private:
  std::string model_name_;
  std::string model_path_;
  std::string model_version_;
  time_t last_update_time_;
};
using MindSporeModelPtr = std::shared_ptr<MindSporeModel>;
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_MODEL_H_
| @@ -0,0 +1,134 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
#include "core/version_control/version_controller.h"
#include <atomic>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include "util/file_system_operation.h"
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "core/server.h"
| namespace mindspore { | |||||
| namespace serving { | |||||
| volatile bool stop_poll = false; | |||||
// Extracts the version component (last path element) from a model path,
// ignoring a single trailing '/'. A path with no '/' is returned unchanged.
// Fix: the original called path.back() without checking for an empty string,
// which is undefined behavior; an empty path now yields an empty version.
std::string GetVersionFromPath(const std::string &path) {
  if (path.empty()) {
    return "";
  }
  std::string new_path = path;
  if (new_path.back() == '/') {
    new_path = new_path.substr(0, new_path.size() - 1);
  }
  std::string::size_type index = new_path.find_last_of("/");
  // When no '/' is found, npos + 1 wraps to 0, so the whole string is
  // returned -- same behavior as the original.
  return new_path.substr(index + 1);
}
// Polling-thread body: periodically rescans models_path_ and, under the
// kLastest strategy, swaps in the newest model version and warms it up.
// Exits when the global stop_poll flag is observed true.
void PeriodicFunction::operator()() {
  while (true) {
    std::this_thread::sleep_for(std::chrono::milliseconds(poll_model_wait_seconds_ * 1000));
    std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
    if (version_control_strategy_ == VersionController::VersionControllerStrategy::kLastest) {
      // The last sub-directory is treated as the newest version (assumes
      // GetAllSubDirs returns a sorted list -- confirm); fall back to the
      // root path when there are no sub-directories.
      auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
      std::string model_version = GetVersionFromPath(path);
      time_t last_update_time = GetModifyTime(path);
      if (model_version != valid_models_.back()->GetModelVersion()) {
        // New version found: replace the tail entry and warm the session up.
        // NOTE(review): valid_models_ is a by-value copy (the constructor
        // takes a const reference into a by-value member), so this update
        // does not reach the owning VersionController -- confirm intent.
        MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(valid_models_.front()->GetModelName(), path,
                                                                       model_version, last_update_time);
        valid_models_.back() = model_ptr;
        Session::Instance().Warmup(valid_models_.back());
      } else {
        // Same version on disk: only refresh the recorded mtime if it moved
        // forward.
        if (difftime(valid_models_.back()->GetLastUpdateTime(), last_update_time) < 0) {
          valid_models_.back()->SetLastUpdateTime(last_update_time);
        }
      }
    } else {
      // not support
    }
    if (stop_poll == true) {
      break;
    }
  }
}
// Initializes the controller with the latest-version strategy; the polling
// thread is NOT started here (see Run / StartPollModelPeriodic).
VersionController::VersionController(int32_t poll_model_wait_seconds, const std::string &models_path,
                                     const std::string &model_name)
    : version_control_strategy_(kLastest),
      poll_model_wait_seconds_(poll_model_wait_seconds),
      models_path_(models_path),
      model_name_(model_name) {}
// Free function: requests the polling thread to exit via the global flag.
void StopPollModelPeriodic() { stop_poll = true; }
// NOTE(review): the unqualified call below resolves to the member function
// VersionController::StopPollModelPeriodic() (defined later in this file),
// NOT to the free function above -- verify the member actually sets stop_poll,
// otherwise join() can block forever once the polling thread is started.
VersionController::~VersionController() {
  StopPollModelPeriodic();
  if (poll_model_thread_.joinable()) {
    poll_model_thread_.join();
  }
}
| Status VersionController::Run() { | |||||
| Status ret; | |||||
| ret = CreateInitModels(); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| // disable periodic check | |||||
| // StartPollModelPeriodic(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status VersionController::CreateInitModels() { | |||||
| if (!DirOrFileExist(models_path_)) { | |||||
| MS_LOG(ERROR) << "Model Path Not Exist!" << std::endl; | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<std::string> SubDirs = GetAllSubDirs(models_path_); | |||||
| if (version_control_strategy_ == kLastest) { | |||||
| auto path = SubDirs.empty() ? models_path_ : SubDirs.back(); | |||||
| std::string model_version = GetVersionFromPath(path); | |||||
| time_t last_update_time = GetModifyTime(path); | |||||
| MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time); | |||||
| valid_models_.emplace_back(model_ptr); | |||||
| } else { | |||||
| for (auto &dir : SubDirs) { | |||||
| std::string model_version = GetVersionFromPath(dir); | |||||
| time_t last_update_time = GetModifyTime(dir); | |||||
| MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, dir, model_version, last_update_time); | |||||
| valid_models_.emplace_back(model_ptr); | |||||
| } | |||||
| } | |||||
| if (valid_models_.empty()) { | |||||
| MS_LOG(ERROR) << "There is no valid model for serving"; | |||||
| return FAILED; | |||||
| } | |||||
| Session::Instance().Warmup(valid_models_.back()); | |||||
| return SUCCESS; | |||||
| } | |||||
// Launches the background polling thread.
// NOTE(review): std::ref(valid_models_) binds to PeriodicFunction's
// const-reference constructor parameter, which copies into a by-value member,
// so the thread operates on a snapshot of the list -- confirm whether a
// reference was intended.
void VersionController::StartPollModelPeriodic() {
  poll_model_thread_ = std::thread(
      PeriodicFunction(poll_model_wait_seconds_, models_path_, version_control_strategy_, std::ref(valid_models_)));
}
| void VersionController::StopPollModelPeriodic() {} | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_ | |||||
| #define MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <thread> | |||||
| #include "./model.h" | |||||
| #include "util/status.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
// Discovers model versions under a models directory, keeps the list of valid
// models, and can poll for new versions on a background thread.
class VersionController {
 public:
  // kLastest (sic): serve only the newest version; kMulti: keep all versions.
  enum VersionControllerStrategy { kLastest = 0, kMulti = 1 };
  VersionController(int32_t poll_model_wait_seconds, const std::string &models_path, const std::string &model_name);
  // Joins the polling thread if it was started.
  ~VersionController();
  // Loads the initial model list; periodic polling is currently disabled.
  Status Run();
  // Starts / stops the background polling thread.
  void StartPollModelPeriodic();
  void StopPollModelPeriodic();

 private:
  // Scans models_path_ and fills valid_models_ per the strategy.
  Status CreateInitModels();

 private:
  VersionControllerStrategy version_control_strategy_;
  std::vector<MindSporeModelPtr> valid_models_;  // known models; .back() is newest
  int32_t poll_model_wait_seconds_;              // polling interval in seconds
  std::thread poll_model_thread_;
  std::string models_path_;                      // root directory of model versions
  std::string model_name_;
};
// Callable executed on the polling thread: rescans the models directory at a
// fixed interval until the global stop flag is set.
class PeriodicFunction {
 public:
  // NOTE(review): valid_models is taken by const reference but stored in the
  // by-value member below, so the thread works on a snapshot of the list even
  // when the caller passes std::ref -- confirm whether a reference member was
  // intended.
  PeriodicFunction(int32_t poll_model_wait_seconds, const std::string &models_path,
                   VersionController::VersionControllerStrategy version_control_strategy,
                   const std::vector<MindSporeModelPtr> &valid_models)
      : poll_model_wait_seconds_(poll_model_wait_seconds),
        models_path_(models_path),
        version_control_strategy_(version_control_strategy),
        valid_models_(valid_models) {}
  ~PeriodicFunction() = default;
  // Thread body; loops until stop_poll is observed true.
  void operator()();

 private:
  int32_t poll_model_wait_seconds_;  // polling interval in seconds
  std::string models_path_;
  VersionController::VersionControllerStrategy version_control_strategy_;
  std::vector<MindSporeModelPtr> valid_models_;  // by-value copy (see NOTE above)
};
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| #endif // !MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_ | |||||
| @@ -0,0 +1,72 @@ | |||||
# Example gRPC client/server build for the MindSpore serving proto.
# NOTE(review): the project name "HelloWorld" looks inherited from the gRPC
# helloworld example -- consider renaming.
cmake_minimum_required(VERSION 3.5.1)
project(HelloWorld C CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
find_package(Threads REQUIRED)
# This branch assumes that gRPC and all its dependencies are already installed
# on this system, so they can be located by find_package().
# Find Protobuf installation
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)
message(STATUS "Using protobuf ${protobuf_VERSION}")
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
set(_REFLECTION gRPC::grpc++_reflection)
# When cross-compiling, protoc must be a host binary located on PATH instead
# of the imported target from the (target-arch) protobuf package.
if(CMAKE_CROSSCOMPILING)
  find_program(_PROTOBUF_PROTOC protoc)
else()
  set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif()
# Find gRPC installation
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
find_package(gRPC CONFIG REQUIRED)
message(STATUS "Using gRPC ${gRPC_VERSION}")
set(_GRPC_GRPCPP gRPC::grpc++)
# Same host-vs-target distinction for the gRPC C++ codegen plugin.
if(CMAKE_CROSSCOMPILING)
  find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else()
  set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
endif()
# Proto file
get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
# Generated sources
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
# Regenerate the protobuf/gRPC sources whenever the .proto changes.
add_custom_command(
      OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
      COMMAND ${_PROTOBUF_PROTOC}
      ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
        --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
        -I "${hw_proto_path}"
        --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
        "${hw_proto}"
      DEPENDS "${hw_proto}")
# Include generated *.pb.h files
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
# Build the ms_client and ms_server executables from the generated sources
# (comment updated: it previously referenced the greeter example targets).
foreach(_target
  ms_client ms_server)
  add_executable(${_target} "${_target}.cc"
    ${hw_proto_srcs}
    ${hw_grpc_srcs})
  target_link_libraries(${_target}
    ${_REFLECTION}
    ${_GRPC_GRPCPP}
    ${_PROTOBUF_LIBPROTOBUF})
endforeach()
| @@ -0,0 +1,105 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <iostream> | |||||
| #include "serving/ms_service.grpc.pb.h" | |||||
| using grpc::Channel; | |||||
| using grpc::ClientContext; | |||||
| using grpc::Status; | |||||
| using ms_serving::MSService; | |||||
| using ms_serving::PredictReply; | |||||
| using ms_serving::PredictRequest; | |||||
| using ms_serving::Tensor; | |||||
| using ms_serving::TensorShape; | |||||
| class MSClient { | |||||
| public: | |||||
| explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} | |||||
| std::string Predict(const std::string &user) { | |||||
| // Data we are sending to the server. | |||||
| PredictRequest request; | |||||
| Tensor data; | |||||
| TensorShape shape; | |||||
| shape.add_dims(1); | |||||
| shape.add_dims(1); | |||||
| shape.add_dims(2); | |||||
| shape.add_dims(2); | |||||
| *data.mutable_tensor_shape() = shape; | |||||
| data.set_tensor_type(ms_serving::MS_FLOAT32); | |||||
| vector<float> input_data{1.1, 2.1, 3.1, 4.1}; | |||||
| data.set_data(input_data.data(), input_data.size()); | |||||
| *request.add_data() = data; | |||||
| *request.add_data() = data; | |||||
| // Container for the data we expect from the server. | |||||
| PredictReply reply; | |||||
| // Context for the client. It could be used to convey extra information to | |||||
| // the server and/or tweak certain RPC behaviors. | |||||
| ClientContext context; | |||||
| // The actual RPC. | |||||
| Status status = stub_->Predict(&context, request, &reply); | |||||
| // Act upon its status. | |||||
| if (status.ok()) { | |||||
| return "RPC OK"; | |||||
| } else { | |||||
| std::cout << status.error_code() << ": " << status.error_message() << std::endl; | |||||
| return "RPC failed"; | |||||
| } | |||||
| } | |||||
| private: | |||||
| std::unique_ptr<MSService::Stub> stub_; | |||||
| }; | |||||
| int main(int argc, char **argv) { | |||||
| // Instantiate the client. It requires a channel, out of which the actual RPCs | |||||
| // are created. This channel models a connection to an endpoint specified by | |||||
| // the argument "--target=" which is the only expected argument. | |||||
| // We indicate that the channel isn't authenticated (use of | |||||
| // InsecureChannelCredentials()). | |||||
| std::string target_str; | |||||
| std::string arg_str("--target"); | |||||
| if (argc > 1) { | |||||
| std::string arg_val = argv[1]; | |||||
| size_t start_pos = arg_val.find(arg_str); | |||||
| if (start_pos != std::string::npos) { | |||||
| start_pos += arg_str.size(); | |||||
| if (arg_val[start_pos] == '=') { | |||||
| target_str = arg_val.substr(start_pos + 1); | |||||
| } else { | |||||
| std::cout << "The only correct argument syntax is --target=" << std::endl; | |||||
| return 0; | |||||
| } | |||||
| } else { | |||||
| std::cout << "The only acceptable argument is --target=" << std::endl; | |||||
| return 0; | |||||
| } | |||||
| } else { | |||||
| target_str = "localhost:85010"; | |||||
| } | |||||
| MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); | |||||
| string request; | |||||
| string reply = client.Predict(request); | |||||
| std::cout << "client received: " << reply << std::endl; | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <grpcpp/health_check_service_interface.h> | |||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||||
| #include <iostream> | |||||
| #include "serving/ms_service.grpc.pb.h" | |||||
| using grpc::Server; | |||||
| using grpc::ServerBuilder; | |||||
| using grpc::ServerContext; | |||||
| using grpc::Status; | |||||
| using ms_serving::MSService; | |||||
| using ms_serving::PredictReply; | |||||
| using ms_serving::PredictRequest; | |||||
| // Logic and data behind the server's behavior. | |||||
| class MSServiceImpl final : public MSService::Service { | |||||
| Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override { | |||||
| cout << "server eval" << endl; | |||||
| return Status::OK; | |||||
| } | |||||
| }; | |||||
// Build and start a synchronous gRPC server on 0.0.0.0:50051, then block
// forever in Wait(). Health checking and server reflection are enabled.
void RunServer() {
  std::string server_address("0.0.0.0:50051");
  MSServiceImpl service;
  grpc::EnableDefaultHealthCheckService(true);
  grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  // Disable SO_REUSEPORT so a second server instance fails to bind
  // instead of silently sharing the port with this one.
  auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
  ServerBuilder builder;
  builder.SetOption(std::move(option));
  // Listen on the given address without any authentication mechanism.
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  // Register "service" as the instance through which we'll communicate with
  // clients. In this case it corresponds to an *synchronous* service.
  builder.RegisterService(&service);
  // Finally assemble the server.
  std::unique_ptr<Server> server(builder.BuildAndStart());
  std::cout << "Server listening on " << server_address << std::endl;
  // Wait for the server to shutdown. Note that some other thread must be
  // responsible for shutting down the server for this call to ever return.
  server->Wait();
}
| int main(int argc, char **argv) { | |||||
| RunServer(); | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "core/server.h" | |||||
| #include "core/util/option_parser.h" | |||||
| using mindspore::serving::Options; | |||||
| int main(int argc, char **argv) { | |||||
| auto flag = Options::Instance().ParseCommandLine(argc, argv); | |||||
| if (!flag) { | |||||
| return 0; | |||||
| } | |||||
| mindspore::serving::Server server; | |||||
| server.BuildAndStart(); | |||||
| return 0; | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
// ms_service.proto
syntax = "proto3";

package ms_serving;

// Inference service exposed by the serving process.
service MSService {
  // Run a prediction over the tensors carried in the request.
  rpc Predict(PredictRequest) returns (PredictReply) {}
  // Test endpoint mirroring Predict's request/reply types.
  rpc Test(PredictRequest) returns (PredictReply) {}
}

// Input tensors for a prediction call.
message PredictRequest {
  repeated Tensor data = 1;
}

// Output tensors produced by a prediction call.
message PredictReply {
  repeated Tensor result = 1;
}

// Element types supported for tensor payloads.
enum DataType {
  MS_UNKNOWN = 0;
  MS_BOOL = 1;
  MS_INT8 = 2;
  MS_UINT8 = 3;
  MS_INT16 = 4;
  MS_UINT16 = 5;
  MS_INT32 = 6;
  MS_UINT32 = 7;
  MS_INT64 = 8;
  MS_UINT64 = 9;
  MS_FLOAT16 = 10;
  MS_FLOAT32 = 11;
  MS_FLOAT64 = 12;
}

// Shape of a tensor: one entry per dimension.
// (Stray trailing ';' after the closing brace removed for consistency with
// the other message definitions in this file.)
message TensorShape {
  repeated int64 dims = 1;
}

// A tensor value: shape, element type and raw content bytes.
message Tensor {
  // tensor shape info
  TensorShape tensor_shape = 1;
  // tensor content data type
  DataType tensor_type = 2;
  // tensor data
  bytes data = 3;
}
| @@ -0,0 +1,57 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import grpc | |||||
| import numpy as np | |||||
| import ms_service_pb2 | |||||
| import ms_service_pb2_grpc | |||||
def run():
    # Issue a single synchronous Predict RPC against a local serving endpoint.
    channel = grpc.insecure_channel('localhost:50051')
    stub = ms_service_pb2_grpc.MSServiceStub(channel)
    # request = ms_service_pb2.PredictRequest()
    # request.name = 'haha'
    # response = stub.Eval(request)
    # print("ms client received: " + response.message)
    request = ms_service_pb2.PredictRequest()
    # NOTE(review): in the accompanying ms_service.proto, PredictRequest.data
    # is a *repeated* Tensor (filled via request.data.add()) and there is no
    # `label` field at all, so the attribute accesses below look like they
    # target an older schema — confirm against the generated pb2 module.
    request.data.tensor_shape.dims.extend([32, 1, 32, 32])
    request.data.tensor_type = ms_service_pb2.MS_FLOAT32
    # 32x1x32x32 float32 input filled with 0.01, serialized as raw bytes.
    request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
    request.label.tensor_shape.dims.extend([32])
    request.label.tensor_type = ms_service_pb2.MS_INT32
    request.label.data = np.ones([32]).astype(np.int32).tobytes()
    result = stub.Predict(request)
    #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
    print("ms client received: ")
    #print(result_np)
    # future_list = []
    # times = 1000
    # for i in range(times):
    #     async_future = stub.Eval.future(request)
    #     future_list.append(async_future)
    #     print("async call, future list add item " + str(i));
    #
    # for i in range(len(future_list)):
    #     async_result = future_list[i].result()
    #     print("ms client async get result of item " + str(i))

if __name__ == '__main__':
    run()
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import grpc | |||||
| import numpy as np | |||||
| import ms_service_pb2 | |||||
| import ms_service_pb2_grpc | |||||
def run():
    # Exercise the Test RPC of the serving endpoint with the same payload
    # shape as the Predict client.
    channel = grpc.insecure_channel('localhost:50051')
    stub = ms_service_pb2_grpc.MSServiceStub(channel)
    # request = ms_service_pb2.EvalRequest()
    # request.name = 'haha'
    # response = stub.Eval(request)
    # print("ms client received: " + response.message)
    request = ms_service_pb2.PredictRequest()
    # NOTE(review): PredictRequest.data is a *repeated* Tensor in the visible
    # ms_service.proto and no `label` field exists there, so these attribute
    # accesses presumably match an older schema — verify before use.
    request.data.tensor_shape.dims.extend([32, 1, 32, 32])
    request.data.tensor_type = ms_service_pb2.MS_FLOAT32
    request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
    request.label.tensor_shape.dims.extend([32])
    request.label.tensor_type = ms_service_pb2.MS_INT32
    request.label.data = np.ones([32]).astype(np.int32).tobytes()
    result = stub.Test(request)
    #result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
    print("ms client test call received: ")
    #print(result_np)

if __name__ == '__main__':
    run()
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from concurrent import futures | |||||
| import time | |||||
| import grpc | |||||
| import numpy as np | |||||
| import ms_service_pb2 | |||||
| import ms_service_pb2_grpc | |||||
| import test_cpu_lenet | |||||
| from mindspore import Tensor | |||||
class MSService(ms_service_pb2_grpc.MSServiceServicer):
    # Mock serving endpoint: each Predict call runs one LeNet training step
    # on the received data/label tensors and returns the loss tensor.
    def Predict(self, request, context):
        # NOTE(review): the visible ms_service.proto declares `data` (and
        # reply `result`) as *repeated* Tensor and has no `label` field, so
        # this handler presumably matches an older schema — confirm before
        # reuse.
        request_data = request.data
        request_label = request.label
        # Rebuild numpy arrays from the raw tensor bytes and declared shapes.
        data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32)
        data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims)
        data = Tensor(data_from_buffer)
        label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32)
        label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims)
        label = Tensor(label_from_buffer)
        # One training step; the resulting loss tensor is echoed back.
        result = test_cpu_lenet.test_lenet(data, label)
        result_reply = ms_service_pb2.PredictReply()
        # NOTE(review): `result.shape()` as a call assumes an API where shape
        # is a method rather than a property — verify against the mindspore
        # version in use.
        result_reply.result.tensor_shape.dims.extend(result.shape())
        result_reply.result.data = result.asnumpy().tobytes()
        return result_reply
def serve():
    """Bring up the mock serving endpoint on port 50051 and block until Ctrl-C."""
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), grpc_server)
    grpc_server.add_insecure_port('[::]:50051')
    grpc_server.start()
    one_day = 60 * 60 * 24  # one day in seconds
    try:
        # grpc's server runs on daemon threads, so keep the main thread alive.
        while True:
            time.sleep(one_day)
    except KeyboardInterrupt:
        grpc_server.stop(0)

if __name__ == '__main__':
    serve()
| @@ -0,0 +1,96 @@ | |||||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | |||||
| import grpc | |||||
| import ms_service_pb2 as ms__service__pb2 | |||||
# Generated stub: do not hand-modify logic; regenerate from ms_service.proto.
class MSServiceStub(object):
    """Missing associated documentation comment in .proto file"""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # One client callable per RPC of ms_serving.MSService; serializers
        # come from the generated ms_service_pb2 module.
        self.Predict = channel.unary_unary(
                '/ms_serving.MSService/Predict',
                request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
                response_deserializer=ms__service__pb2.PredictReply.FromString,
                )
        self.Test = channel.unary_unary(
                '/ms_serving.MSService/Test',
                request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
                response_deserializer=ms__service__pb2.PredictReply.FromString,
                )
# Generated servicer base class: override the methods in a subclass to
# implement the service; defaults answer UNIMPLEMENTED.
class MSServiceServicer(object):
    """Missing associated documentation comment in .proto file"""

    def Predict(self, request, context):
        """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Test(self, request, context):
        """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
# Generated registration helper: wires a servicer instance into a grpc
# server under the fully qualified service name 'ms_serving.MSService'.
def add_MSServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Predict': grpc.unary_unary_rpc_method_handler(
                    servicer.Predict,
                    request_deserializer=ms__service__pb2.PredictRequest.FromString,
                    response_serializer=ms__service__pb2.PredictReply.SerializeToString,
            ),
            'Test': grpc.unary_unary_rpc_method_handler(
                    servicer.Test,
                    request_deserializer=ms__service__pb2.PredictRequest.FromString,
                    response_serializer=ms__service__pb2.PredictReply.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'ms_serving.MSService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
 # This class is part of an EXPERIMENTAL API.
class MSService(object):
    """Missing associated documentation comment in .proto file"""

    # Connectionless convenience wrappers: each call resolves `target`,
    # performs the RPC once, and returns the deserialized reply.
    @staticmethod
    def Predict(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict',
            ms__service__pb2.PredictRequest.SerializeToString,
            ms__service__pb2.PredictReply.FromString,
            options, channel_credentials,
            call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def Test(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test',
            ms__service__pb2.PredictRequest.SerializeToString,
            ms__service__pb2.PredictReply.FromString,
            options, channel_credentials,
            call_credentials, compression, wait_for_ready, timeout, metadata)
| @@ -0,0 +1,91 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import Momentum | |||||
| from mindspore.ops import operations as P | |||||
| import ms_service_pb2 | |||||
class LeNet(nn.Cell):
    """Classic LeNet-5 style CNN: two conv/pool stages followed by three
    fully-connected layers, producing 10-class logits for a fixed batch of 32.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32
        # Attribute names are kept stable so checkpoints/parameters keep
        # the same keys as before.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        """Forward pass: conv -> relu -> pool twice, flatten, then the FC stack."""
        x = self.pool(self.relu(self.conv1(input_x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.reshape(x, (self.batch_size, -1))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
def train(net, data, label):
    """Run a single training step of `net` on (data, label) and return the loss.

    Uses SGD with momentum and sparse softmax cross-entropy; the network is
    put into training mode before the step.
    """
    lr = 0.01
    momentum = 0.9
    trainable = filter(lambda x: x.requires_grad, net.get_parameters())
    optimizer = Momentum(trainable, lr, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    train_network = TrainOneStepCell(WithLossCell(net, criterion), optimizer)
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    assert res
    return res
def test_lenet(data, label):
    """Train a fresh LeNet for one step on CPU (graph mode); return the loss."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    return train(LeNet(), data, label)
if __name__ == '__main__':
    # Build a sample Tensor proto, round-trip its bytes through numpy, and
    # run one LeNet training step on the result.
    tensor = ms_service_pb2.Tensor()
    # Fix: the proto field is `dims` (repeated int64 in TensorShape); the
    # original accessed `dim`, which raises AttributeError at runtime.
    tensor.tensor_shape.dims.extend([32, 1, 32, 32])
    tensor.tensor_type = ms_service_pb2.MS_FLOAT32
    tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes()
    # Reconstruct the array from the raw bytes using the declared shape.
    data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32)
    print(tensor.tensor_shape.dims)
    data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dims)
    print(data_from_buffer.shape)
    input_data = Tensor(data_from_buffer * 0.01)
    input_label = Tensor(np.ones([32]).astype(np.int32))
    test_lenet(input_data, input_label)
| @@ -0,0 +1,105 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
set -e

# Locate clang-format; abort if it is not installed. A brace group (not a
# subshell) makes `exit 1` terminate the script directly instead of relying
# on `set -e` to propagate the subshell's failure status.
CLANG_FORMAT=$(command -v clang-format) || { echo "Please install 'clang-format' tool first"; exit 1; }

# Require clang-format >= 8 (major version extracted from `--version`).
version=$("${CLANG_FORMAT}" --version | sed -n "s/.*\ \([0-9]*\)\.[0-9]*\.[0-9]*.*/\1/p")
if [[ "${version}" -lt "8" ]]; then
  echo "clang-format's version must be at least 8.0.0"
  exit 1
fi

CURRENT_PATH=$(pwd)
SCRIPTS_PATH=$(dirname "$0")
echo "CURRENT_PATH=${CURRENT_PATH}"
echo "SCRIPTS_PATH=${SCRIPTS_PATH}"
| # print usage message | |||||
function usage()
{
  # All help text in one heredoc so the lines stay together; output is
  # byte-identical to the previous echo-per-line form ($0 still expands).
  cat <<EOF
Format the specified source files to conform the code style.
Usage:
bash $0 [-a] [-c] [-l] [-h]
e.g. $0 -c

Options:
 -a format of all files
 -c format of the files changed compared to last commit, default case
 -l format of the files changed in last commit
 -h Print usage
EOF
}
| # check and set options | |||||
function checkopts()
{
  # Default mode: format only files changed since the last commit.
  mode="changed"
  while getopts 'aclh' opt; do
    case "${opt}" in
      a) mode="all" ;;          # every tracked .h/.cc file
      c) mode="changed" ;;      # files modified since last commit
      l) mode="lastcommit" ;;   # files touched by the last commit
      h)
        usage
        exit 0
        ;;
      *)
        echo "Unknown option ${opt}!"
        usage
        exit 1
    esac
  done
}
# Parse command-line options into `mode`.
checkopts "$@"

# switch to project root path, which contains clang-format config file '.clang-format'
cd "${SCRIPTS_PATH}/.." || exit 1

# Collect the .h/.cc files to format, one path per line.
FMT_FILE_LIST='__format_files_list__'
if [[ "X${mode}" == "Xall" ]]; then
  # `-type f` already matches every file; the redundant `-name "*"` dropped.
  find ./ -type f | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
elif [[ "X${mode}" == "Xchanged" ]]; then
  git diff --name-only | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
else  # "X${mode}" == "Xlastcommit"
  git diff --name-only HEAD~ HEAD | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
fi

# `IFS=` plus `read -r` keep leading whitespace and backslashes in file
# paths intact; plain `read` would mangle them.
while IFS= read -r line; do
  if [ -f "${line}" ]; then
    ${CLANG_FORMAT} -i "${line}"
  fi
done < "${FMT_FILE_LIST}"
rm "${FMT_FILE_LIST}"

cd "${CURRENT_PATH}" || exit 1
echo "Specified cpp source files have been format successfully."