You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dispacther.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "master/dispacther.h"
#include <algorithm>
#include <memory>
#include <shared_mutex>
#include <string>
#include <utility>
#include <vector>
#include "common/proto_tensor.h"
#include "master/master_context.h"
#include "master/notify_worker/grpc_notify.h"
#include "master/notify_worker/local_notify.h"
  22. namespace mindspore::serving {
// Construct a dispatcher with an empty servable registry.
Dispatcher::Dispatcher() {}
// On destruction, notify all registered workers to exit and drop the registry.
Dispatcher::~Dispatcher() { Clear(); }
  25. DispatcherWorkerContext Dispatcher::GetWorkSession(const RequestSpec &request_spec) const {
  26. Status status;
  27. DispatcherWorkerContext context;
  28. auto it = servable_map_.find(request_spec.servable_name);
  29. if (it == servable_map_.end()) {
  30. return context;
  31. }
  32. if (request_spec.version_number > 0) {
  33. auto item = find_if(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &v) {
  34. return v.worker_spec.version_number == request_spec.version_number;
  35. });
  36. if (item != it->second.end()) {
  37. context.worker_spec = item->worker_spec;
  38. context.notify_worker_ = item->notify_worker_;
  39. }
  40. return context;
  41. }
  42. uint64_t max_version_number = 0;
  43. for (const auto &item : it->second) {
  44. if (max_version_number < item.worker_spec.version_number) {
  45. context.worker_spec = item.worker_spec;
  46. context.notify_worker_ = item.notify_worker_;
  47. max_version_number = item.worker_spec.version_number;
  48. }
  49. }
  50. return context;
  51. }
  52. Status Dispatcher::JudgeInferNum() {
  53. auto max_infer_num = MasterContext::Instance()->GetMaxRequestBufferCount();
  54. if (infer_num_ >= max_infer_num) {
  55. return INFER_STATUS_LOG_ERROR(FAILED) << "Serving Error: request buffer number exceeds the limit " << max_infer_num;
  56. }
  57. return SUCCESS;
  58. }
  59. void Dispatcher::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply,
  60. PredictOnFinish on_finish) {
  61. MSI_EXCEPTION_IF_NULL(reply);
  62. (*reply->mutable_servable_spec()) = request.servable_spec();
  63. Status status = JudgeInferNum();
  64. if (status != SUCCESS) {
  65. GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
  66. on_finish();
  67. return;
  68. }
  69. try {
  70. auto callback = [this, on_finish]() {
  71. on_finish();
  72. this->infer_num_--;
  73. };
  74. infer_num_++;
  75. status = DispatchAsyncInner(request, reply, callback);
  76. } catch (const std::bad_alloc &ex) {
  77. MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
  78. std::cout << "Serving Error: malloc memory failed" << std::endl;
  79. } catch (const std::runtime_error &ex) {
  80. MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what();
  81. std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl;
  82. } catch (const std::exception &ex) {
  83. MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what();
  84. std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl;
  85. } catch (...) {
  86. MSI_LOG(ERROR) << "Serving Error: exception occurred";
  87. std::cout << "Serving Error: exception occurred";
  88. }
  89. MSI_LOG(INFO) << "Finish call service Eval";
  90. if (status != SUCCESS) {
  91. GrpcTensorHelper::CreateReplyFromErrorMsg(status, reply);
  92. on_finish();
  93. infer_num_--;
  94. }
  95. }
  96. Status Dispatcher::DispatchAsyncInner(const proto::PredictRequest &request, proto::PredictReply *reply,
  97. PredictOnFinish on_finish) {
  98. MSI_EXCEPTION_IF_NULL(reply);
  99. std::shared_lock<std::shared_mutex> lock(servable_shared_lock_);
  100. RequestSpec request_spec;
  101. GrpcTensorHelper::GetRequestSpec(request, &request_spec);
  102. auto worker = GetWorkSession(request_spec);
  103. if (!worker.notify_worker_) {
  104. return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", servable is not available";
  105. }
  106. bool find_method =
  107. std::any_of(worker.worker_spec.methods.begin(), worker.worker_spec.methods.end(),
  108. [&](const WorkerMethodInfo &method) { return method.name == request_spec.method_name; });
  109. if (!find_method) {
  110. return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available";
  111. }
  112. return worker.notify_worker_->DispatchAsync(request, reply, std::move(on_finish));
  113. }
  114. Status Dispatcher::RegisterServableCommon(const std::vector<WorkerSpec> &worker_specs, CreateNotifyWorkerFunc func) {
  115. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  116. if (worker_specs.empty()) {
  117. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable cannot be empty";
  118. }
  119. MSI_EXCEPTION_IF_NULL(func);
  120. for (auto &worker_spec : worker_specs) {
  121. if (worker_spec.servable_name.empty()) {
  122. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name cannot be empty";
  123. }
  124. if (worker_spec.version_number <= 0) {
  125. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name " << worker_spec.servable_name
  126. << " version number " << worker_spec.version_number << " cannot be 0";
  127. }
  128. auto it = servable_map_.find(worker_spec.servable_name);
  129. std::shared_ptr<BaseNotifyWorker> notify_worker = func(worker_spec);
  130. bool find_registered = false;
  131. if (it != servable_map_.end()) {
  132. auto item = find_if(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &v) {
  133. return v.worker_spec.version_number == worker_spec.version_number &&
  134. v.worker_spec.worker_address == worker_spec.worker_address;
  135. });
  136. if (item != it->second.end()) {
  137. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  138. << " has been registered, old registered info will be replaced";
  139. item->worker_spec = worker_spec;
  140. item->notify_worker_ = notify_worker;
  141. find_registered = true;
  142. }
  143. }
  144. if (!find_registered) {
  145. DispatcherWorkerContext context;
  146. context.worker_spec = worker_spec;
  147. context.notify_worker_ = notify_worker;
  148. servable_map_[worker_spec.servable_name].push_back(context);
  149. }
  150. }
  151. return SUCCESS;
  152. }
  153. Status Dispatcher::UnregisterServableCommon(const std::string &worker_address) {
  154. if (clearing_flag) {
  155. return SUCCESS;
  156. }
  157. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  158. Status status;
  159. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  160. for (auto it = iter->second.begin(); it != iter->second.end();) {
  161. if (worker_address == it->worker_spec.worker_address) {
  162. it = iter->second.erase(it);
  163. } else {
  164. ++it;
  165. }
  166. }
  167. if (iter->second.size() == 0) {
  168. iter = servable_map_.erase(iter);
  169. } else {
  170. ++iter;
  171. }
  172. }
  173. return SUCCESS;
  174. }
  175. Status Dispatcher::AddServableCommon(const WorkerSpec &worker_spec, CreateNotifyWorkerFunc func) {
  176. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  177. MSI_EXCEPTION_IF_NULL(func);
  178. if (worker_spec.servable_name.empty()) {
  179. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name cannot be empty";
  180. }
  181. if (worker_spec.version_number <= 0) {
  182. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name " << worker_spec.servable_name
  183. << " version number " << worker_spec.version_number << " cannot be 0";
  184. }
  185. Status status;
  186. auto it = servable_map_.find(worker_spec.servable_name);
  187. if (it != servable_map_.end()) {
  188. bool find = std::any_of(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &item) {
  189. return item.worker_spec.version_number == worker_spec.version_number &&
  190. item.worker_spec.worker_address == worker_spec.worker_address;
  191. });
  192. if (find) {
  193. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  194. << " has been registered";
  195. return SUCCESS;
  196. }
  197. }
  198. DispatcherWorkerContext context;
  199. context.worker_spec = worker_spec;
  200. context.notify_worker_ = func(worker_spec);
  201. servable_map_[worker_spec.servable_name].push_back(context);
  202. return SUCCESS;
  203. }
  204. Status Dispatcher::RemoveServableCommon(const WorkerSpec &worker_spec) {
  205. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  206. Status status;
  207. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  208. for (auto it = iter->second.begin(); it != iter->second.end();) {
  209. if (worker_spec.worker_address == it->worker_spec.worker_address &&
  210. it->worker_spec.servable_name == worker_spec.servable_name &&
  211. it->worker_spec.version_number == worker_spec.version_number) {
  212. it = iter->second.erase(it);
  213. } else {
  214. ++it;
  215. }
  216. }
  217. if (iter->second.size() == 0) {
  218. iter = servable_map_.erase(iter);
  219. } else {
  220. ++iter;
  221. }
  222. }
  223. return SUCCESS;
  224. }
  225. Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply * /*reply*/) {
  226. std::vector<WorkerSpec> worker_specs;
  227. GrpcTensorHelper::GetWorkerSpec(request, &worker_specs);
  228. auto create_notify_worker = [](const WorkerSpec &worker_spec) {
  229. std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address);
  230. return notify_worker;
  231. };
  232. return RegisterServableCommon(worker_specs, create_notify_worker);
  233. }
  234. Status Dispatcher::UnregisterServable(const proto::ExitRequest &request, proto::ExitReply * /*reply*/) {
  235. return UnregisterServableCommon(request.address());
  236. }
  237. Status Dispatcher::AddServable(const proto::AddWorkerRequest &request, proto::AddWorkerReply * /*reply*/) {
  238. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  239. WorkerSpec worker_spec;
  240. GrpcTensorHelper::GetWorkerSpec(request, &worker_spec);
  241. auto create_notify_worker = [](const WorkerSpec &worker_spec) {
  242. std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<GrpcNotifyWorker>(worker_spec.worker_address);
  243. return notify_worker;
  244. };
  245. return AddServableCommon(worker_spec, create_notify_worker);
  246. }
  247. Status Dispatcher::RemoveServable(const proto::RemoveWorkerRequest &request, proto::RemoveWorkerReply * /*reply*/) {
  248. WorkerSpec worker_spec;
  249. GrpcTensorHelper::GetWorkerSpec(request, &worker_spec);
  250. return RemoveServableCommon(worker_spec);
  251. }
  252. void Dispatcher::Clear() {
  253. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  254. clearing_flag = true;
  255. for (auto iter = servable_map_.begin(); iter != servable_map_.end(); ++iter) {
  256. for (auto it = iter->second.begin(); it != iter->second.end(); ++it) {
  257. if (it->notify_worker_) {
  258. it->notify_worker_->Exit();
  259. }
  260. }
  261. }
  262. servable_map_.clear();
  263. }
  264. Status Dispatcher::RegisterLocalServable(const std::vector<WorkerSpec> &worker_specs) {
  265. auto create_notify_worker = [](const WorkerSpec &worker_spec) {
  266. std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<LocalNotifyWorker>();
  267. return notify_worker;
  268. };
  269. return RegisterServableCommon(worker_specs, create_notify_worker);
  270. }
  271. Status Dispatcher::UnregisterLocalServable() { return UnregisterServableCommon(""); }
  272. Status Dispatcher::AddLocalServable(const WorkerSpec &worker_spec) {
  273. auto create_notify_worker = [](const WorkerSpec &worker_spec) {
  274. std::shared_ptr<BaseNotifyWorker> notify_worker = std::make_shared<LocalNotifyWorker>();
  275. return notify_worker;
  276. };
  277. return AddServableCommon(worker_spec, create_notify_worker);
  278. }
  279. Status Dispatcher::RemoveLocalServable(const WorkerSpec &worker_spec) { return RemoveServableCommon(worker_spec); }
  280. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.