
test_servable_common.h

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H
#define MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H

#include <sys/stat.h>  // mkdir, mode_t
#include <unistd.h>    // rmdir
#include <cstdio>      // remove
#include <fstream>
#include <future>      // std::promise, std::future
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "common/common_test.h"
#include "master/server.h"
#include "worker/worker.h"
#include "worker/notfiy_master/local_notify.h"
#include "worker/context.h"
#include "worker/local_servable/local_sevable.h"
#include "master/grpc/grpc_process.h"
#include "mindspore_serving/proto/ms_service.pb.h"

namespace mindspore {
namespace serving {

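// Checks that `error_msg` contains `expected_msg`; on mismatch, prints both
// strings to aid debugging.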
#define ExpectContainMsg(error_msg, expected_msg)                                                       \
  {                                                                                                     \
    auto error_msg_str = error_msg;                                                                     \
    EXPECT_TRUE(error_msg_str.find(expected_msg) != std::string::npos);                                 \
    if (error_msg_str.find(expected_msg) == std::string::npos) {                                        \
      std::cout << "error_msg: " << error_msg_str << ", expected_msg: " << expected_msg << std::endl;   \
    }                                                                                                   \
  }
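
// Test fixture that creates an on-disk servable layout
// (<servable_dir>/<servable_name>/<version_number>/<model_file>) in Init(),
// then removes the created files and directories and clears the Worker and
// Server singletons in TearDown().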
class TestMasterWorker : public UT::Common {
 public:
  TestMasterWorker() = default;
  void Init(std::string servable_dir, std::string servable_name, int version_number, std::string model_file) {
    servable_dir_ = servable_dir;
    servable_name_ = servable_name;
    version_number_ = version_number;
    model_file_ = model_file;
    servable_name_path_ = servable_dir_ + "/" + servable_name_;
    version_number_path_ = servable_name_path_ + "/" + std::to_string(version_number_);
    model_name_path_ = version_number_path_ + "/" + model_file_;
    mode_t access_mode = S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH;  // 0775
    mkdir(servable_dir_.c_str(), access_mode);
    mkdir(servable_name_path_.c_str(), access_mode);
    mkdir(version_number_path_.c_str(), access_mode);
    std::ofstream fp(model_name_path_);
    fp << "model content";
    fp.close();
    model_name_path_list_.emplace(model_name_path_);
    version_number_path_list_.emplace(version_number_path_);
    servable_name_path_list_.emplace(servable_name_path_);
    servable_dir_list_.emplace(servable_dir_);
  }
  virtual void SetUp() {}
  virtual void TearDown() {
    for (auto &item : model_name_path_list_) {
      remove(item.c_str());
    }
    for (auto &item : version_number_path_list_) {
      rmdir(item.c_str());
    }
    for (auto &item : servable_name_path_list_) {
      rmdir(item.c_str());
    }
    for (auto &item : servable_dir_list_) {
      rmdir(item.c_str());
    }
    Worker::GetInstance().Clear();
    Server::Instance().Clear();
  }
  void StartAddServable() {
    auto status = StartServable(servable_dir_, servable_name_, 0);
    ASSERT_TRUE(status.IsSuccess());
  }
  void RegisterAddServable(bool with_batch_dim = false) {
    DeclareServable(servable_name_, model_file_, "mindir", with_batch_dim);
    // register_method
    RegisterMethod(servable_name_, "add_common", {"x1", "x2"}, {"y"}, 2, 1);
  }
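  // Starts the given servable version on device 0 (Ascend) through a
  // LocalModelServable and registers it with the Worker singleton.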
  static Status StartServable(const std::string &servable_dir, const std::string &servable_name, int version_number) {
    auto notify_master = std::make_shared<LocalNotifyMaster>();
    ServableContext::Instance()->SetDeviceId(0);
    ServableContext::Instance()->SetDeviceTypeStr("Ascend");
    auto servable = std::make_shared<LocalModelServable>();
    auto status = servable->StartServable(servable_dir, servable_name, version_number);
    if (status != SUCCESS) {
      return status;
    }
    status = Worker::GetInstance().StartServable(servable, notify_master);
    return status;
  }
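  // Declares servable metadata (name, model file, model format, batch-dim
  // flag) to the ServableStorage singleton.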
  static void DeclareServable(const std::string &servable_name, const std::string &servable_file,
                              const std::string &model_type, bool with_batch_dim = false) {
    ServableMeta servable_meta;
    servable_meta.common_meta.servable_name = servable_name;
    servable_meta.common_meta.with_batch_dim = with_batch_dim;
    servable_meta.local_meta.servable_file = servable_file;
    servable_meta.local_meta.SetModelFormat(model_type);
    // declare_servable
    ServableStorage::Instance().DeclareServable(servable_meta);
  }
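  // Registers a method signature whose inputs 0 and 1 map directly to the
  // servable inputs and whose single return value is servable output 0.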
  static Status RegisterMethod(const std::string &servable_name, const std::string &method_name,
                               const std::vector<std::string> &input_names,
                               const std::vector<std::string> &output_names, size_t servable_input_count,
                               size_t servable_output_count) {
    auto status =
      ServableStorage::Instance().RegisterInputOutputInfo(servable_name, servable_input_count, servable_output_count);
    if (status != SUCCESS) {
      return status;
    }
    MethodSignature method_signature;
    method_signature.servable_name = servable_name;
    method_signature.method_name = method_name;
    method_signature.inputs = input_names;
    method_signature.outputs = output_names;
    // method input 0 and input 1 as servable input
    method_signature.servable_inputs = {{kPredictPhaseTag_Input, 0}, {kPredictPhaseTag_Input, 1}};
    // servable output as method output
    method_signature.returns = {{kPredictPhaseTag_Predict, 0}};
    ServableStorage::Instance().RegisterMethod(method_signature);
    return SUCCESS;
  }

  std::string servable_dir_;
  std::string servable_name_;
  int version_number_ = 0;
  std::string model_file_;
  std::string model_name_path_;
  std::string version_number_path_;
  std::string servable_name_path_;
  std::set<std::string> servable_dir_list_;
  std::set<std::string> model_name_path_list_;
  std::set<std::string> version_number_path_list_;
  std::set<std::string> servable_name_path_list_;
};
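
// Extends TestMasterWorker with helpers to build proto::PredictRequest
// messages for the "add" servable, dispatch them through the gRPC service
// implementation, and verify the replies.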
class TestMasterWorkerClient : public TestMasterWorker {
 public:
  TestMasterWorkerClient() = default;
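  // Fills a proto::Tensor with the given shape, data type and raw data.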
  static void InitTensor(proto::Tensor *tensor, const std::vector<int64_t> &shape, proto::DataType data_type,
                         const void *data, size_t data_size) {
    MSI_EXCEPTION_IF_NULL(tensor);
    tensor->set_dtype(data_type);
    auto proto_shape = tensor->mutable_shape();
    for (auto item : shape) {
      proto_shape->add_dims(item);
    }
    tensor->set_data(data, data_size);
  }
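  // Builds a request holding a single instance (inputs x1, x2 of shape
  // [2, 2]) and returns the expected output y = x1 + x2.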
  static std::vector<float> InitOneInstanceRequest(proto::PredictRequest *request, const std::string &servable_name,
                                                   const std::string &method_name, int version_number) {
    MSI_EXCEPTION_IF_NULL(request);
    auto request_servable_spec = request->mutable_servable_spec();
    request_servable_spec->set_name(servable_name);
    request_servable_spec->set_method_name(method_name);
    request_servable_spec->set_version_number(version_number);
    std::vector<float> x1_data = {1.1, 2.2, 3.3, 4.4};
    std::vector<float> x2_data = {1.2, 2.3, 3.4, 4.5};
    std::vector<float> y_data;
    for (size_t i = 0; i < x1_data.size(); i++) {
      y_data.push_back(x1_data[i] + x2_data[i]);
    }
    auto instance = request->add_instances();
    auto &input_map = (*instance->mutable_items());
    // input x1
    InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float));
    // input x2
    InitTensor(&input_map["x2"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float));
    return y_data;
  }
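  // Builds a request with `instances_count` instances of shape [2, 2]; the
  // k-th instance's inputs are scaled by (k + 1). Returns the expected
  // outputs, one vector per instance.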
  template <class IN_DT = float, class OUT_DT = float>
  static std::vector<std::vector<OUT_DT>> InitMultiInstancesRequest(proto::PredictRequest *request,
                                                                    const std::string &servable_name,
                                                                    const std::string &method_name, int version_number,
                                                                    size_t instances_count) {
    MSI_EXCEPTION_IF_NULL(request);
    auto request_servable_spec = request->mutable_servable_spec();
    request_servable_spec->set_name(servable_name);
    request_servable_spec->set_method_name(method_name);
    request_servable_spec->set_version_number(version_number);
    auto data_type = proto::MS_FLOAT32;
    if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) {
      data_type = proto::MS_INT32;
    }
    std::vector<std::vector<OUT_DT>> y_data_list;
    for (size_t k = 0; k < instances_count; k++) {
      std::vector<float> x1_data_org = {1.1, 2.2, 3.3, 4.4};
      std::vector<float> x2_data_org = {6.6, 7.7, 8.8, 9.9};
      std::vector<IN_DT> x1_data;
      std::vector<IN_DT> x2_data;
      std::vector<OUT_DT> y_data;
      for (size_t i = 0; i < x1_data_org.size(); i++) {
        x1_data.push_back(static_cast<IN_DT>(x1_data_org[i] * (k + 1)));
        x2_data.push_back(static_cast<IN_DT>(x2_data_org[i] * (k + 1)));
        y_data.push_back(static_cast<OUT_DT>(x1_data[i] + x2_data[i]));
      }
      y_data_list.push_back(y_data);
      auto instance = request->add_instances();
      auto &input_map = (*instance->mutable_items());
      // input x1
      InitTensor(&input_map["x1"], {2, 2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT));
      // input x2
      InitTensor(&input_map["x2"], {2, 2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT));
    }
    return y_data_list;
  }
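  // Same as InitMultiInstancesRequest, but each input has shape [2] instead
  // of [2, 2].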
  template <class IN_DT = float, class OUT_DT = float>
  static std::vector<std::vector<OUT_DT>> InitMultiInstancesShape2Request(proto::PredictRequest *request,
                                                                          const std::string &servable_name,
                                                                          const std::string &method_name,
                                                                          int version_number, size_t instances_count) {
    MSI_EXCEPTION_IF_NULL(request);
    auto request_servable_spec = request->mutable_servable_spec();
    request_servable_spec->set_name(servable_name);
    request_servable_spec->set_method_name(method_name);
    request_servable_spec->set_version_number(version_number);
    auto data_type = proto::MS_FLOAT32;
    if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) {
      data_type = proto::MS_INT32;
    }
    std::vector<std::vector<OUT_DT>> y_data_list;
    for (size_t k = 0; k < instances_count; k++) {
      std::vector<float> x1_data_org = {1.1, 2.2};
      std::vector<float> x2_data_org = {8.8, 9.9};
      std::vector<IN_DT> x1_data;
      std::vector<IN_DT> x2_data;
      std::vector<OUT_DT> y_data;
      for (size_t i = 0; i < x1_data_org.size(); i++) {
        x1_data.push_back(static_cast<IN_DT>(x1_data_org[i] * (k + 1)));
        x2_data.push_back(static_cast<IN_DT>(x2_data_org[i] * (k + 1)));
        y_data.push_back(static_cast<OUT_DT>(x1_data[i] + x2_data[i]));
      }
      y_data_list.push_back(y_data);
      auto instance = request->add_instances();
      auto &input_map = (*instance->mutable_items());
      // input x1
      InitTensor(&input_map["x1"], {2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT));
      // input x2
      InitTensor(&input_map["x2"], {2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT));
    }
    return y_data_list;
  }
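  // Verifies that the reply holds `instances_count` error-free instances and
  // that each output "y" matches the corresponding expected data.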
  template <class OUT_DT>
  static void CheckMultiInstanceResult(const proto::PredictReply &reply,
                                       const std::vector<std::vector<OUT_DT>> &y_data_list,
                                       size_t instances_count) {  // check output
    ASSERT_EQ(reply.instances_size(), instances_count);
    ASSERT_EQ(reply.error_msg_size(), 0);
    auto data_type = proto::MS_FLOAT32;
    if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) {
      data_type = proto::MS_INT32;
    }
    std::vector<int64_t> shape;
    if (y_data_list[0].size() == 4) {
      shape = {2, 2};
    } else {
      shape = {2};
    }
    for (size_t k = 0; k < instances_count; k++) {
      auto &output_instance = reply.instances(k);
      ASSERT_EQ(output_instance.items_size(), 1);
      auto &output_items = output_instance.items();
      ASSERT_EQ(output_items.begin()->first, "y");
      auto &output_tensor = output_items.begin()->second;
      CheckTensor(output_tensor, shape, data_type, y_data_list[k].data(), y_data_list[k].size() * sizeof(OUT_DT));
    }
  }
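  // Single-instance variant of CheckMultiInstanceResult.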
  template <class OUT_DT>
  static void CheckInstanceResult(const proto::PredictReply &reply, const std::vector<OUT_DT> &y_data) {
    // check output
    ASSERT_EQ(reply.instances_size(), 1);
    ASSERT_EQ(reply.error_msg_size(), 0);
    auto data_type = proto::MS_FLOAT32;
    if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) {
      data_type = proto::MS_INT32;
    }
    std::vector<int64_t> shape;
    if (y_data.size() == 4) {
      shape = {2, 2};
    } else {
      shape = {2};
    }
    auto &output_instance = reply.instances(0);
    ASSERT_EQ(output_instance.items_size(), 1);
    auto &output_items = output_instance.items();
    ASSERT_EQ(output_items.begin()->first, "y");
    auto &output_tensor = output_items.begin()->second;
    CheckTensor(output_tensor, shape, data_type, y_data.data(), y_data.size() * sizeof(OUT_DT));
  }
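  // Compares a reply tensor against the expected dtype, shape and raw data,
  // element by element for MS_FLOAT32 and MS_INT32.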
  static void CheckTensor(const proto::Tensor &output_tensor, const std::vector<int64_t> &shape,
                          proto::DataType data_type, const void *data, size_t data_size) {
    EXPECT_EQ(output_tensor.dtype(), data_type);
    // check shape
    auto &output_tensor_shape = output_tensor.shape();
    ASSERT_EQ(output_tensor_shape.dims_size(), shape.size());
    std::vector<int64_t> proto_shape;
    for (int i = 0; i < output_tensor_shape.dims_size(); i++) {
      proto_shape.push_back(output_tensor_shape.dims(i));
    }
    EXPECT_EQ(proto_shape, shape);
    // check data
    ASSERT_EQ(output_tensor.data().size(), data_size);
    switch (data_type) {
      case proto::MS_FLOAT32: {
        auto data_len = data_size / sizeof(float);
        auto real_data = reinterpret_cast<const float *>(output_tensor.data().data());
        auto expect_data = reinterpret_cast<const float *>(data);
        for (size_t i = 0; i < data_len; i++) {
          EXPECT_EQ(real_data[i], expect_data[i]);
          if (real_data[i] != expect_data[i]) {
            break;  // report only the first mismatch
          }
        }
        break;
      }
      case proto::MS_INT32: {
        auto data_len = data_size / sizeof(int32_t);
        auto real_data = reinterpret_cast<const int32_t *>(output_tensor.data().data());
        auto expect_data = reinterpret_cast<const int32_t *>(data);
        for (size_t i = 0; i < data_len; i++) {
          EXPECT_EQ(real_data[i], expect_data[i]);
          if (real_data[i] != expect_data[i]) {
            break;  // report only the first mismatch
          }
        }
        break;
      }
      default:
        FAIL();
    }
  }
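  // Sends the request through the gRPC service implementation and blocks
  // until the asynchronous predict callback completes.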
  static grpc::Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) {
    MSServiceImpl impl(Server::Instance().GetDispatcher());
    grpc::ServerContext context;
    auto promise = std::make_shared<std::promise<void>>();
    auto future = promise->get_future();
    PredictOnFinish callback = [promise]() { promise->set_value(); };
    impl.PredictAsync(&request, reply, callback);
    future.get();
    return grpc::Status::OK;
  }
};
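
// A minimal usage sketch (not part of the original file; the test name,
// directory and model file name below are hypothetical). A test would
// typically prepare the servable, build a request, dispatch it, and verify
// the reply:
//
//   TEST_F(TestMasterWorkerClient, TestAddOneInstance) {
//     Init("./servable_dir", "add_servable", 1, "add_servable.mindir");
//     RegisterAddServable();
//     StartAddServable();
//     proto::PredictRequest request;
//     proto::PredictReply reply;
//     auto y_data = InitOneInstanceRequest(&request, servable_name_, "add_common", 0);
//     EXPECT_TRUE(Dispatch(request, &reply).ok());
//     CheckInstanceResult(reply, y_data);
//   }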
}  // namespace serving
}  // namespace mindspore
#endif  // MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H
