You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they may include dashes ('-') and can be up to 35 characters long.

server.cc 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "core/server.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <atomic>
#include <csignal>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
#include "core/util/option_parser.h"
#include "core/version_control/version_controller.h"
#include "mindspore/ccsrc/utils/context/ms_context.h"
#include "core/util/file_system_operation.h"
#include "graphengine/third_party/fwkacllib/inc/runtime/context.h"
  33. using ms_serving::MSService;
  34. using ms_serving::PredictReply;
  35. using ms_serving::PredictRequest;
  36. namespace mindspore {
  37. namespace serving {
  38. using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
// Creates the backend inference session for the requested device target.
// @param device    device target name (e.g. "Ascend", "GPU") forwarded to the
//                  inference backend.
// @param device_id physical device index.
// @return SUCCESS when the backend session was created, FAILED otherwise.
// Note: on failure session_ is left null, so later Predict/Warmup calls bail out.
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
  session_ = inference::MSSession::CreateSession(device, device_id);
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "Creat Session Failed";
    return FAILED;
  }
  // Remember the target so Warmup can load models for the same backend.
  device_type_ = device;
  return SUCCESS;
}
  48. Session &Session::Instance() {
  49. static Session instance;
  50. return instance;
  51. }
  52. Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) {
  53. if (last_graph_ == nullptr) {
  54. MS_LOG(ERROR) << "the model has not loaded";
  55. return FAILED;
  56. }
  57. if (session_ == nullptr) {
  58. MS_LOG(ERROR) << "the inference session has not be initialized";
  59. return FAILED;
  60. }
  61. std::lock_guard<std::mutex> lock(mutex_);
  62. MS_LOG(INFO) << "run Predict";
  63. *outputs = session_->RunGraph(graph_id_, inputs);
  64. MS_LOG(INFO) << "run Predict finished";
  65. return SUCCESS;
  66. }
// Loads and compiles a model so the first real request does not pay model-load
// and graph-compilation latency.
// @param model descriptor providing the model directory and file name.
// @return SUCCESS after load+compile, FAILED on any step.
Status Session::Warmup(const MindSporeModelPtr model) {
  if (session_ == nullptr) {
    MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
    return FAILED;
  }
  // Serialize against Predict(): compiling swaps last_graph_/graph_id_.
  std::lock_guard<std::mutex> lock(mutex_);
  size_t size = 0;
  std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
  // NOTE(review): ReadFile returns a heap buffer that is never released on any
  // path below — this looks like a leak unless LoadModel takes ownership of
  // graphBuf. TODO: confirm ReadFile/LoadModel's ownership contract and free
  // the buffer here if it is ours.
  char *graphBuf = ReadFile(file_name.c_str(), &size);
  if (graphBuf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
    return FAILED;
  }
  last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
  if (last_graph_ == nullptr) {
    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
    return FAILED;
  }
  // Compile once up front; Predict() replays the compiled graph by id.
  graph_id_ = session_->CompileGraph(last_graph_);
  MS_LOG(INFO) << "Session Warmup finished";
  return SUCCESS;
}
// Releases the backend inference session (process shutdown path). Subsequent
// Predict/Warmup calls will fail their session_ null checks.
Status Session::Clear() {
  session_ = nullptr;
  return SUCCESS;
}
  93. namespace {
// NOTE(review): despite the name, this is INT32_MAX (0x7FFFFFFF), not the
// maximum of uint32_t. The value itself is deliberate — it is passed to
// grpc::ServerBuilder::SetMaxMessageSize, which takes a signed int — but the
// identifier is misleading.
static const uint32_t uint32max = 0x7FFFFFFF;
// Fulfilled by the SIGINT handler; BuildAndStart blocks on the paired future.
std::promise<void> exit_requested;
// Wire dtype -> backend TypeId (MS_UNKNOWN maps to the sentinel kNumberTypeBegin).
const std::map<ms_serving::DataType, TypeId> type2id_map{
  {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
  {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
  {ms_serving::MS_INT16, TypeId::kNumberTypeInt16}, {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
  {ms_serving::MS_INT32, TypeId::kNumberTypeInt32}, {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
  {ms_serving::MS_INT64, TypeId::kNumberTypeInt64}, {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
  {ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
  {ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
};
// Backend TypeId -> wire dtype; exact inverse of type2id_map.
const std::map<TypeId, ms_serving::DataType> id2type_map{
  {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
  {TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
  {TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
  {TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
  {TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
  {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
  {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
};
// Per-element byte width of each wire dtype (FLOAT16 has no C++ scalar type,
// hence the literal 2).
const std::map<ms_serving::DataType, size_t> length_map{
  {ms_serving::MS_UNKNOWN, 0},
  {ms_serving::MS_BOOL, sizeof(bool)},
  {ms_serving::MS_INT8, sizeof(int8_t)},
  {ms_serving::MS_UINT8, sizeof(uint8_t)},
  {ms_serving::MS_INT16, sizeof(int16_t)},
  {ms_serving::MS_UINT16, sizeof(uint16_t)},
  {ms_serving::MS_INT32, sizeof(int32_t)},
  {ms_serving::MS_UINT32, sizeof(uint32_t)},
  {ms_serving::MS_INT64, sizeof(int64_t)},
  {ms_serving::MS_UINT64, sizeof(uint64_t)},
  {ms_serving::MS_FLOAT16, 2},
  {ms_serving::MS_FLOAT32, 4},
  {ms_serving::MS_FLOAT64, 8},
};
  129. MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
  130. std::vector<int> shape;
  131. for (auto dim : tensor.tensor_shape().dims()) {
  132. shape.push_back(static_cast<int>(dim));
  133. }
  134. auto iter = type2id_map.find(tensor.tensor_type());
  135. if (iter == type2id_map.end()) {
  136. MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
  137. return nullptr;
  138. }
  139. TypeId type = iter->second;
  140. auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
  141. memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
  142. return ms_tensor;
  143. }
  144. ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) {
  145. ms_serving::Tensor tensor;
  146. ms_serving::TensorShape shape;
  147. for (auto dim : ms_tensor->shape()) {
  148. shape.add_dims(dim);
  149. }
  150. *tensor.mutable_tensor_shape() = shape;
  151. auto iter = id2type_map.find(ms_tensor->data_type());
  152. if (iter == id2type_map.end()) {
  153. MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
  154. return tensor;
  155. }
  156. tensor.set_tensor_type(iter->second);
  157. tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size());
  158. return tensor;
  159. }
// Shutdown helper: drops the serving session, then tears down the inference
// runtime. Order matters — the session must be released before the runtime
// it was created on exits.
void ClearEnv() {
  Session::Instance().Clear();
  inference::ExitInference();
}
  164. void HandleSignal(int sig) { exit_requested.set_value(); }
#ifdef ENABLE_D
// Ascend runtime context captured by BuildAndStart after model compilation;
// each Predict RPC re-binds it to the gRPC worker thread via rtCtxSetCurrent.
static rtContext_t g_ctx = nullptr;
#endif
}  // namespace
  169. // Service Implement
  170. class MSServiceImpl final : public MSService::Service {
  171. grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
  172. std::lock_guard<std::mutex> lock(mutex_);
  173. #ifdef ENABLE_D
  174. if (g_ctx == nullptr) {
  175. MS_LOG(ERROR) << "rtCtx is nullptr";
  176. return grpc::Status::CANCELLED;
  177. }
  178. rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
  179. if (rt_ret != RT_ERROR_NONE) {
  180. MS_LOG(ERROR) << "set Ascend rtCtx failed";
  181. }
  182. #endif
  183. std::vector<MSTensorPtr> inputs;
  184. inference::MultiTensor outputs;
  185. for (int i = 0; i < request->data_size(); i++) {
  186. auto input = ServingTensor2MSTensor(request->data(i));
  187. if (input == nullptr) {
  188. MS_LOG(ERROR) << "Tensor convert failed";
  189. return grpc::Status::CANCELLED;
  190. }
  191. inputs.push_back(input);
  192. }
  193. auto res = Session::Instance().Predict(inputs, &outputs);
  194. if (res != SUCCESS) {
  195. return grpc::Status::CANCELLED;
  196. }
  197. for (const auto &tensor : outputs) {
  198. *reply->add_result() = MSTensor2ServingTensor(tensor);
  199. }
  200. MS_LOG(INFO) << "Finish call service Eval";
  201. return grpc::Status::OK;
  202. }
  203. grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
  204. MS_LOG(INFO) << "TestService call";
  205. return grpc::Status::OK;
  206. }
  207. std::mutex mutex_;
  208. };
  209. Status Server::BuildAndStart() {
  210. // handle exit signal
  211. signal(SIGINT, HandleSignal);
  212. Status res;
  213. auto option_args = Options::Instance().GetArgs();
  214. std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port);
  215. std::string model_path = option_args->model_path;
  216. std::string model_name = option_args->model_name;
  217. std::string device_type = option_args->device_type;
  218. auto device_id = option_args->device_id;
  219. res = Session::Instance().CreatDeviceSession(device_type, device_id);
  220. if (res != SUCCESS) {
  221. MS_LOG(ERROR) << "creat session failed";
  222. ClearEnv();
  223. return res;
  224. }
  225. VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
  226. res = version_controller.Run();
  227. if (res != SUCCESS) {
  228. MS_LOG(ERROR) << "load model failed";
  229. ClearEnv();
  230. return res;
  231. }
  232. #ifdef ENABLE_D
  233. // set d context
  234. rtContext_t ctx = nullptr;
  235. rtError_t rt_ret = rtCtxGetCurrent(&ctx);
  236. if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
  237. MS_LOG(ERROR) << "the ascend device context is null";
  238. ClearEnv();
  239. return FAILED;
  240. }
  241. g_ctx = ctx;
  242. #endif
  243. MSServiceImpl service;
  244. grpc::EnableDefaultHealthCheckService(true);
  245. grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  246. // Set the port is not reuseable
  247. auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
  248. grpc::ServerBuilder builder;
  249. builder.SetOption(std::move(option));
  250. builder.SetMaxMessageSize(uint32max);
  251. // Listen on the given address without any authentication mechanism.
  252. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  253. // Register "service" as the instance through which we'll communicate with
  254. // clients. In this case it corresponds to an *synchronous* service.
  255. builder.RegisterService(&service);
  256. // Finally assemble the server.
  257. std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
  258. if (server == nullptr) {
  259. MS_LOG(ERROR) << "The serving server create failed";
  260. ClearEnv();
  261. return FAILED;
  262. }
  263. auto grpc_server_run = [&server]() { server->Wait(); };
  264. std::thread serving_thread(grpc_server_run);
  265. MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
  266. auto exit_future = exit_requested.get_future();
  267. exit_future.wait();
  268. ClearEnv();
  269. server->Shutdown();
  270. serving_thread.join();
  271. return SUCCESS;
  272. }
  273. } // namespace serving
  274. } // namespace mindspore