You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dispacther.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "master/dispacther.h"
  17. #include "worker/worker.h"
  18. #include "common/proto_tensor.h"
  19. namespace mindspore::serving {
  20. Dispatcher::Dispatcher() {}
  21. Dispatcher::~Dispatcher() { Clear(); }
  22. DispatcherWorkerContext Dispatcher::GetWorkSession(const RequestSpec &request_spec) const {
  23. Status status;
  24. DispatcherWorkerContext context;
  25. auto it = servable_map_.find(request_spec.servable_name);
  26. if (it == servable_map_.end()) {
  27. return context;
  28. }
  29. if (request_spec.version_number > 0) {
  30. auto item = find_if(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &v) {
  31. return v.worker_spec.version_number == request_spec.version_number;
  32. });
  33. if (item != it->second.end()) {
  34. context.worker_spec = item->worker_spec;
  35. context.stub_ = item->stub_;
  36. context.worker_running_in_master = item->worker_running_in_master;
  37. }
  38. return context;
  39. }
  40. uint64_t max_version_number = 0;
  41. for (const auto &item : it->second) {
  42. if (max_version_number < item.worker_spec.version_number) {
  43. context.worker_spec = item.worker_spec;
  44. context.stub_ = item.stub_;
  45. context.worker_running_in_master = item.worker_running_in_master;
  46. max_version_number = item.worker_spec.version_number;
  47. }
  48. }
  49. return context;
  50. }
  51. Status Dispatcher::Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) {
  52. MSI_EXCEPTION_IF_NULL(reply);
  53. std::shared_lock<std::shared_mutex> lock(servable_shared_lock_);
  54. RequestSpec request_spec;
  55. GrpcTensorHelper::GetRequestSpec(request, &request_spec);
  56. auto worker = GetWorkSession(request_spec);
  57. if (!worker.stub_ && !worker.worker_running_in_master) {
  58. return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", servable is not available";
  59. }
  60. bool find_method =
  61. std::any_of(worker.worker_spec.methods.begin(), worker.worker_spec.methods.end(),
  62. [&](const WorkerMethodInfo &method) { return method.name == request_spec.method_name; });
  63. if (!find_method) {
  64. return INFER_STATUS_LOG_ERROR(INVALID_INPUTS) << "Request " << request_spec.Repr() << ", method is not available";
  65. }
  66. /// TODO spec request version_number
  67. if (worker.stub_ != nullptr) {
  68. grpc::ClientContext context;
  69. auto status = worker.stub_->Predict(&context, request, reply);
  70. if (!status.ok()) {
  71. return INFER_STATUS_LOG_ERROR(FAILED)
  72. << "Predict failed, worker gRPC error: " << status.error_code() << ", " << status.error_message();
  73. }
  74. } else {
  75. return Worker::GetInstance().Run(request, reply);
  76. }
  77. return SUCCESS;
  78. }
  79. Status Dispatcher::RegisterServable(const proto::RegisterRequest &request, proto::RegisterReply * /*reply*/) {
  80. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  81. std::vector<WorkerSpec> worker_specs;
  82. GrpcTensorHelper::GetWorkerSpec(request, &worker_specs);
  83. if (worker_specs.empty()) {
  84. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable cannot be empty";
  85. }
  86. for (auto &worker_spec : worker_specs) {
  87. if (worker_spec.servable_name.empty()) {
  88. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name cannot be empty";
  89. }
  90. if (worker_spec.version_number <= 0) {
  91. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name " << worker_spec.servable_name
  92. << " version number " << worker_spec.version_number << " cannot be 0";
  93. }
  94. auto target_str = request.address();
  95. auto it = servable_map_.find(worker_spec.servable_name);
  96. std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  97. bool find_registered = false;
  98. if (it != servable_map_.end()) {
  99. std::shared_ptr<Worker> worker;
  100. auto item = find_if(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &v) {
  101. return v.worker_spec.version_number == worker_spec.version_number &&
  102. v.worker_spec.worker_address == worker_spec.worker_address;
  103. });
  104. if (item != it->second.end()) {
  105. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  106. << " has been registered, old registered info will be replaced";
  107. item->worker_spec = worker_spec;
  108. item->stub_ = proto::MSWorker::NewStub(channel);
  109. find_registered = true;
  110. }
  111. }
  112. if (!find_registered) {
  113. DispatcherWorkerContext context;
  114. context.worker_spec = worker_spec;
  115. context.stub_ = proto::MSWorker::NewStub(channel);
  116. servable_map_[worker_spec.servable_name].push_back(context);
  117. }
  118. }
  119. return SUCCESS;
  120. }
  121. Status Dispatcher::UnregisterServable(const proto::ExitRequest &request, proto::ExitReply * /*reply*/) {
  122. if (clearing_flag) {
  123. return SUCCESS;
  124. }
  125. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  126. auto target_str = request.address();
  127. Status status;
  128. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  129. for (auto it = iter->second.begin(); it != iter->second.end();) {
  130. if (target_str == it->worker_spec.worker_address) {
  131. it = iter->second.erase(it);
  132. } else {
  133. ++it;
  134. }
  135. }
  136. if (iter->second.size() == 0) {
  137. iter = servable_map_.erase(iter);
  138. } else {
  139. ++iter;
  140. }
  141. }
  142. return SUCCESS;
  143. }
  144. Status Dispatcher::AddServable(const proto::AddWorkerRequest &request, proto::AddWorkerReply * /*reply*/) {
  145. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  146. WorkerSpec worker_spec;
  147. GrpcTensorHelper::GetWorkerSpec(request, &worker_spec);
  148. auto target_str = request.address();
  149. if (worker_spec.servable_name.empty()) {
  150. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name cannot be empty";
  151. }
  152. if (worker_spec.version_number <= 0) {
  153. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name " << worker_spec.servable_name
  154. << " version number " << worker_spec.version_number << " cannot be 0";
  155. }
  156. Status status;
  157. auto it = servable_map_.find(worker_spec.servable_name);
  158. if (it != servable_map_.end()) {
  159. bool find = std::any_of(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &item) {
  160. return item.worker_spec.version_number == worker_spec.version_number &&
  161. item.worker_spec.worker_address == worker_spec.worker_address;
  162. });
  163. if (find) {
  164. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  165. << " has been registered";
  166. return SUCCESS;
  167. }
  168. }
  169. DispatcherWorkerContext context;
  170. context.worker_spec = worker_spec;
  171. std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
  172. context.stub_ = proto::MSWorker::NewStub(channel);
  173. servable_map_[worker_spec.servable_name].push_back(context);
  174. return SUCCESS;
  175. }
  176. Status Dispatcher::RemoveServable(const proto::RemoveWorkerRequest &request, proto::RemoveWorkerReply * /*reply*/) {
  177. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  178. WorkerSpec worker_spec;
  179. GrpcTensorHelper::GetWorkerSpec(request, &worker_spec);
  180. auto target_str = request.address();
  181. Status status;
  182. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  183. for (auto it = iter->second.begin(); it != iter->second.end();) {
  184. if (target_str == it->worker_spec.worker_address && it->worker_spec.servable_name == worker_spec.servable_name &&
  185. it->worker_spec.version_number == worker_spec.version_number) {
  186. it = iter->second.erase(it);
  187. } else {
  188. ++it;
  189. }
  190. }
  191. if (iter->second.size() == 0) {
  192. iter = servable_map_.erase(iter);
  193. } else {
  194. ++iter;
  195. }
  196. }
  197. return SUCCESS;
  198. }
  199. void Dispatcher::Clear() {
  200. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  201. clearing_flag = true;
  202. for (auto iter = servable_map_.begin(); iter != servable_map_.end(); ++iter) {
  203. for (auto it = iter->second.begin(); it != iter->second.end(); ++it) {
  204. proto::ExitRequest request;
  205. request.set_address(it->worker_spec.worker_address);
  206. proto::ExitReply reply;
  207. grpc::ClientContext context;
  208. const int32_t TIME_OUT = 1;
  209. std::chrono::system_clock::time_point deadline =
  210. std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT);
  211. context.set_deadline(deadline);
  212. if (it->stub_) {
  213. (void)it->stub_->Exit(&context, request, &reply);
  214. } else {
  215. Worker::GetInstance().Clear();
  216. }
  217. }
  218. }
  219. servable_map_.clear();
  220. }
  221. Status Dispatcher::RegisterLocalServable(const std::vector<WorkerSpec> &worker_specs) {
  222. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  223. if (worker_specs.empty()) {
  224. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable cannot be empty";
  225. }
  226. for (auto &worker_spec : worker_specs) {
  227. if (worker_spec.servable_name.empty()) {
  228. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name cannot be empty";
  229. }
  230. if (worker_spec.version_number <= 0) {
  231. return INFER_STATUS_LOG_ERROR(FAILED) << "Register failed, servable name " << worker_spec.servable_name
  232. << " version number " << worker_spec.version_number << " cannot be 0";
  233. }
  234. auto it = servable_map_.find(worker_spec.servable_name);
  235. bool find_registered = false;
  236. if (it != servable_map_.end()) {
  237. std::shared_ptr<Worker> worker;
  238. auto item = find_if(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &v) {
  239. return v.worker_spec.version_number == worker_spec.version_number &&
  240. v.worker_spec.worker_address == worker_spec.worker_address;
  241. });
  242. if (item != it->second.end()) {
  243. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  244. << " has been registered, old registered info will be replaced";
  245. item->worker_spec = worker_spec;
  246. find_registered = true;
  247. }
  248. }
  249. if (!find_registered) {
  250. DispatcherWorkerContext context;
  251. context.worker_spec = worker_spec;
  252. context.worker_running_in_master = true;
  253. servable_map_[worker_spec.servable_name].push_back(context);
  254. }
  255. }
  256. return SUCCESS;
  257. }
  258. Status Dispatcher::UnregisterLocalServable() {
  259. if (clearing_flag) {
  260. return SUCCESS;
  261. }
  262. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  263. Status status;
  264. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  265. for (auto it = iter->second.begin(); it != iter->second.end();) {
  266. if (it->worker_running_in_master) {
  267. it = iter->second.erase(it);
  268. } else {
  269. ++it;
  270. }
  271. }
  272. if (iter->second.size() == 0) {
  273. iter = servable_map_.erase(iter);
  274. } else {
  275. ++iter;
  276. }
  277. }
  278. return SUCCESS;
  279. }
  280. Status Dispatcher::AddLocalServable(const WorkerSpec &worker_spec) {
  281. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  282. if (worker_spec.servable_name.empty()) {
  283. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name cannot be empty";
  284. }
  285. if (worker_spec.version_number <= 0) {
  286. return INFER_STATUS_LOG_ERROR(FAILED) << "AddServable failed, servable name " << worker_spec.servable_name
  287. << " version number " << worker_spec.version_number << " cannot be 0";
  288. }
  289. Status status;
  290. auto it = servable_map_.find(worker_spec.servable_name);
  291. if (it != servable_map_.end()) {
  292. bool find = std::any_of(it->second.begin(), it->second.end(), [&](const DispatcherWorkerContext &item) {
  293. return item.worker_spec.version_number == worker_spec.version_number &&
  294. item.worker_spec.worker_address == worker_spec.worker_address;
  295. });
  296. if (find) {
  297. MSI_LOG_WARNING << "Servable " << worker_spec.servable_name << " version " << worker_spec.version_number
  298. << " has been registered";
  299. return SUCCESS;
  300. }
  301. }
  302. DispatcherWorkerContext context;
  303. context.worker_spec = worker_spec;
  304. context.worker_running_in_master = true;
  305. servable_map_[worker_spec.servable_name].push_back(context);
  306. return SUCCESS;
  307. }
  308. Status Dispatcher::RemoveLocalServable(const WorkerSpec &worker_spec) {
  309. std::unique_lock<std::shared_mutex> lock(servable_shared_lock_);
  310. Status status;
  311. for (auto iter = servable_map_.begin(); iter != servable_map_.end();) {
  312. for (auto it = iter->second.begin(); it != iter->second.end();) {
  313. if (it->worker_running_in_master && it->worker_spec.servable_name == worker_spec.servable_name &&
  314. it->worker_spec.version_number == worker_spec.version_number) {
  315. it = iter->second.erase(it);
  316. } else {
  317. ++it;
  318. }
  319. }
  320. if (iter->second.size() == 0) {
  321. iter = servable_map_.erase(iter);
  322. } else {
  323. ++iter;
  324. }
  325. }
  326. return SUCCESS;
  327. }
  328. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.