You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_servable_common.h 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H
  17. #define MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdio>
#include <fstream>
#include <future>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include "common/common_test.h"
#include "master/server.h"
#include "worker/worker.h"
#include "worker/notfiy_master/local_notify.h"
#include "worker/context.h"
#include "master/grpc/grpc_process.h"
#include "mindspore_serving/proto/ms_service.pb.h"
  30. namespace mindspore {
  31. namespace serving {
  32. #define ExpectContainMsg(error_msg, expected_msg) \
  33. { \
  34. auto error_msg_str = error_msg; \
  35. EXPECT_TRUE(error_msg_str.find(expected_msg) != std::string::npos); \
  36. if (error_msg_str.find(expected_msg) == std::string::npos) { \
  37. std::cout << "error_msg: " << error_msg_str << ", expected_msg: " << expected_msg << std::endl; \
  38. } \
  39. }
  40. class TestMasterWorker : public UT::Common {
  41. public:
  42. TestMasterWorker() = default;
  43. void Init(std::string servable_dir, std::string servable_name, int version_number, std::string model_file) {
  44. servable_dir_ = servable_dir;
  45. servable_name_ = servable_name;
  46. version_number_ = version_number;
  47. model_file_ = model_file;
  48. servable_name_path_ = servable_dir_ + "/" + servable_name_;
  49. version_number_path_ = servable_name_path_ + "/" + std::to_string(version_number_);
  50. model_name_path_ = version_number_path_ + "/" + model_file_;
  51. __mode_t access_mode = S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH;
  52. mkdir(servable_dir_.c_str(), access_mode);
  53. mkdir(servable_name_path_.c_str(), access_mode);
  54. mkdir(version_number_path_.c_str(), access_mode);
  55. std::ofstream fp(model_name_path_);
  56. fp << "model content";
  57. fp.close();
  58. model_name_path_list_.emplace(model_name_path_);
  59. version_number_path_list_.emplace(version_number_path_);
  60. servable_name_path_list_.emplace(servable_name_path_);
  61. servable_dir_list_.emplace(servable_dir_);
  62. }
  63. virtual void SetUp() {}
  64. virtual void TearDown() {
  65. for (auto &item : model_name_path_list_) {
  66. remove(item.c_str());
  67. }
  68. for (auto &item : version_number_path_list_) {
  69. rmdir(item.c_str());
  70. }
  71. for (auto &item : servable_name_path_list_) {
  72. rmdir(item.c_str());
  73. }
  74. for (auto &item : servable_dir_list_) {
  75. rmdir(item.c_str());
  76. }
  77. Worker::GetInstance().Clear();
  78. Server::Instance().Clear();
  79. }
  80. void StartAddServable() {
  81. auto status = StartServable(servable_dir_, servable_name_, 0);
  82. ASSERT_TRUE(status.IsSuccess());
  83. }
  84. void RegisterAddServable(bool with_batch_dim = false) {
  85. DeclareServable(servable_name_, model_file_, "mindir", with_batch_dim);
  86. // register_method
  87. RegisterMethod(servable_name_, "add_common", {"x1", "x2"}, {"y"}, 2, 1);
  88. }
  89. static Status StartServable(const std::string &servable_dir, const std::string &servable_name, int version_number) {
  90. auto notify_master = std::make_shared<LocalNotifyMaster>();
  91. ServableContext::Instance()->SetDeviceId(0);
  92. ServableContext::Instance()->SetDeviceTypeStr("Ascend");
  93. Status status = Worker::GetInstance().StartServable(servable_dir, servable_name, version_number, notify_master);
  94. return status;
  95. }
  96. static void DeclareServable(const std::string &servable_name, const std::string &servable_file,
  97. const std::string &model_type, bool with_batch_dim = false) {
  98. ServableMeta servable_meta;
  99. servable_meta.servable_name = servable_name;
  100. servable_meta.servable_file = servable_file;
  101. servable_meta.SetModelFormat(model_type);
  102. servable_meta.with_batch_dim = with_batch_dim;
  103. // declare_servable
  104. ServableStorage::Instance().DeclareServable(servable_meta);
  105. }
  106. static Status RegisterMethod(const std::string &servable_name, const std::string &method_name,
  107. const std::vector<std::string> &input_names,
  108. const std::vector<std::string> &output_names, size_t servable_input_count,
  109. size_t servable_output_count) {
  110. auto status =
  111. ServableStorage::Instance().RegisterInputOutputInfo(servable_name, servable_input_count, servable_output_count);
  112. if (status != SUCCESS) {
  113. return status;
  114. }
  115. MethodSignature method_signature;
  116. method_signature.servable_name = servable_name;
  117. method_signature.method_name = method_name;
  118. method_signature.inputs = input_names;
  119. method_signature.outputs = output_names;
  120. // method input 0 and input 1 as servable input
  121. method_signature.servable_inputs = {{kPredictPhaseTag_Input, 0}, {kPredictPhaseTag_Input, 1}};
  122. // servable output as method output
  123. method_signature.returns = {{kPredictPhaseTag_Predict, 0}};
  124. ServableStorage::Instance().RegisterMethod(method_signature);
  125. return SUCCESS;
  126. }
  127. std::string servable_dir_;
  128. std::string servable_name_;
  129. int version_number_ = 0;
  130. std::string model_file_;
  131. std::string model_name_path_;
  132. std::string version_number_path_;
  133. std::string servable_name_path_;
  134. std::set<std::string> servable_dir_list_;
  135. std::set<std::string> model_name_path_list_;
  136. std::set<std::string> version_number_path_list_;
  137. std::set<std::string> servable_name_path_list_;
  138. };
  139. class TestMasterWorkerClient : public TestMasterWorker {
  140. public:
  141. TestMasterWorkerClient() = default;
  142. static void InitTensor(proto::Tensor *tensor, const std::vector<int64_t> &shape, proto::DataType data_type,
  143. const void *data, size_t data_size) {
  144. MSI_EXCEPTION_IF_NULL(tensor);
  145. tensor->set_dtype(data_type);
  146. auto proto_shape = tensor->mutable_shape();
  147. for (auto item : shape) {
  148. proto_shape->add_dims(item);
  149. }
  150. tensor->set_data(data, data_size);
  151. }
  152. static std::vector<float> InitOneInstanceRequest(proto::PredictRequest *request, const std::string &servable_name,
  153. const std::string &method_name, int version_number) {
  154. MSI_EXCEPTION_IF_NULL(request);
  155. auto request_servable_spec = request->mutable_servable_spec();
  156. request_servable_spec->set_name(servable_name);
  157. request_servable_spec->set_method_name(method_name);
  158. request_servable_spec->set_version_number(version_number);
  159. std::vector<float> x1_data = {1.1, 2.2, 3.3, 4.4};
  160. std::vector<float> x2_data = {1.2, 2.3, 3.4, 4.5};
  161. std::vector<float> y_data;
  162. for (size_t i = 0; i < x1_data.size(); i++) {
  163. y_data.push_back(x1_data[i] + x2_data[i]);
  164. }
  165. auto instance = request->add_instances();
  166. auto &input_map = (*instance->mutable_items());
  167. // input x1
  168. InitTensor(&input_map["x1"], {2, 2}, proto::MS_FLOAT32, x1_data.data(), x1_data.size() * sizeof(float));
  169. // input x2
  170. InitTensor(&input_map["x2"], {2, 2}, proto::MS_FLOAT32, x2_data.data(), x2_data.size() * sizeof(float));
  171. return y_data;
  172. }
  173. template <class IN_DT = float, class OUT_DT = float>
  174. static std::vector<std::vector<OUT_DT>> InitMultiInstancesRequest(proto::PredictRequest *request,
  175. const std::string &servable_name,
  176. const std::string &method_name, int version_number,
  177. size_t instances_count) {
  178. MSI_EXCEPTION_IF_NULL(request);
  179. auto request_servable_spec = request->mutable_servable_spec();
  180. request_servable_spec->set_name(servable_name);
  181. request_servable_spec->set_method_name(method_name);
  182. request_servable_spec->set_version_number(version_number);
  183. auto data_type = proto::MS_FLOAT32;
  184. if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) {
  185. data_type = proto::MS_INT32;
  186. }
  187. std::vector<std::vector<OUT_DT>> y_data_list;
  188. for (size_t k = 0; k < instances_count; k++) {
  189. std::vector<float> x1_data_org = {1.1, 2.2, 3.3, 4.4};
  190. std::vector<float> x2_data_org = {6.6, 7.7, 8.8, 9.9};
  191. std::vector<IN_DT> x1_data;
  192. std::vector<IN_DT> x2_data;
  193. std::vector<OUT_DT> y_data;
  194. for (size_t i = 0; i < x1_data_org.size(); i++) {
  195. x1_data.push_back(static_cast<IN_DT>(x1_data_org[i] * (k + 1)));
  196. x2_data.push_back(static_cast<IN_DT>(x2_data_org[i] * (k + 1)));
  197. y_data.push_back(static_cast<OUT_DT>(x1_data[i] + x2_data[i]));
  198. }
  199. y_data_list.push_back(y_data);
  200. auto instance = request->add_instances();
  201. auto &input_map = (*instance->mutable_items());
  202. // input x1
  203. InitTensor(&input_map["x1"], {2, 2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT));
  204. // input x2
  205. InitTensor(&input_map["x2"], {2, 2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT));
  206. }
  207. return y_data_list;
  208. }
  209. template <class IN_DT = float, class OUT_DT = float>
  210. static std::vector<std::vector<OUT_DT>> InitMultiInstancesShape2Request(proto::PredictRequest *request,
  211. const std::string &servable_name,
  212. const std::string &method_name,
  213. int version_number, size_t instances_count) {
  214. MSI_EXCEPTION_IF_NULL(request);
  215. auto request_servable_spec = request->mutable_servable_spec();
  216. request_servable_spec->set_name(servable_name);
  217. request_servable_spec->set_method_name(method_name);
  218. request_servable_spec->set_version_number(version_number);
  219. auto data_type = proto::MS_FLOAT32;
  220. if (std::string(typeid(IN_DT).name()) == std::string(typeid(int32_t).name())) {
  221. data_type = proto::MS_INT32;
  222. }
  223. std::vector<std::vector<OUT_DT>> y_data_list;
  224. for (size_t k = 0; k < instances_count; k++) {
  225. std::vector<float> x1_data_org = {1.1, 2.2};
  226. std::vector<float> x2_data_org = {8.8, 9.9};
  227. std::vector<IN_DT> x1_data;
  228. std::vector<IN_DT> x2_data;
  229. std::vector<OUT_DT> y_data;
  230. for (size_t i = 0; i < x1_data_org.size(); i++) {
  231. x1_data.push_back(static_cast<IN_DT>(x1_data_org[i] * (k + 1)));
  232. x2_data.push_back(static_cast<IN_DT>(x2_data_org[i] * (k + 1)));
  233. y_data.push_back(x1_data[i] + x2_data[i]);
  234. }
  235. y_data_list.push_back(y_data);
  236. auto instance = request->add_instances();
  237. auto &input_map = (*instance->mutable_items());
  238. // input x1
  239. InitTensor(&input_map["x1"], {2}, data_type, x1_data.data(), x1_data.size() * sizeof(IN_DT));
  240. // input x2
  241. InitTensor(&input_map["x2"], {2}, data_type, x2_data.data(), x2_data.size() * sizeof(IN_DT));
  242. }
  243. return y_data_list;
  244. }
  245. template <class OUT_DT>
  246. static void CheckMultiInstanceResult(const proto::PredictReply &reply,
  247. const std::vector<std::vector<OUT_DT>> &y_data_list,
  248. size_t instances_count) { // checkout output
  249. ASSERT_EQ(reply.instances_size(), instances_count);
  250. ASSERT_EQ(reply.error_msg_size(), 0);
  251. auto data_type = proto::MS_FLOAT32;
  252. if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) {
  253. data_type = proto::MS_INT32;
  254. }
  255. std::vector<int64_t> shape;
  256. if (y_data_list[0].size() == 4) {
  257. shape = {2, 2};
  258. } else {
  259. shape = {2};
  260. }
  261. for (size_t k = 0; k < instances_count; k++) {
  262. auto &output_instance = reply.instances(k);
  263. ASSERT_EQ(output_instance.items_size(), 1);
  264. auto &output_items = output_instance.items();
  265. ASSERT_EQ(output_items.begin()->first, "y");
  266. auto &output_tensor = output_items.begin()->second;
  267. CheckTensor(output_tensor, shape, data_type, y_data_list[k].data(), y_data_list[k].size() * sizeof(OUT_DT));
  268. }
  269. }
  270. template <class OUT_DT>
  271. static void CheckInstanceResult(const proto::PredictReply &reply, const std::vector<OUT_DT> &y_data) {
  272. // checkout output
  273. ASSERT_EQ(reply.instances_size(), 1);
  274. ASSERT_EQ(reply.error_msg_size(), 0);
  275. auto data_type = proto::MS_FLOAT32;
  276. if (std::string(typeid(OUT_DT).name()) == std::string(typeid(int32_t).name())) {
  277. data_type = proto::MS_INT32;
  278. }
  279. std::vector<int64_t> shape;
  280. if (y_data.size() == 4) {
  281. shape = {2, 2};
  282. } else {
  283. shape = {2};
  284. }
  285. auto &output_instance = reply.instances(0);
  286. ASSERT_EQ(output_instance.items_size(), 1);
  287. auto &output_items = output_instance.items();
  288. ASSERT_EQ(output_items.begin()->first, "y");
  289. auto &output_tensor = output_items.begin()->second;
  290. CheckTensor(output_tensor, shape, data_type, y_data.data(), y_data.size() * sizeof(OUT_DT));
  291. }
  292. static void CheckTensor(const proto::Tensor &output_tensor, const std::vector<int64_t> &shape,
  293. proto::DataType data_type, const void *data, size_t data_size) {
  294. EXPECT_EQ(output_tensor.dtype(), data_type);
  295. // check shape [2,2]
  296. auto &output_tensor_shape = output_tensor.shape();
  297. ASSERT_EQ(output_tensor_shape.dims_size(), shape.size());
  298. std::vector<int64_t> proto_shape;
  299. for (size_t i = 0; i < output_tensor_shape.dims_size(); i++) {
  300. proto_shape.push_back(output_tensor_shape.dims(i));
  301. }
  302. EXPECT_EQ(proto_shape, shape);
  303. // check data
  304. ASSERT_EQ(output_tensor.data().size(), data_size);
  305. switch (data_type) {
  306. case proto::MS_FLOAT32: {
  307. auto data_len = data_size / sizeof(float);
  308. auto real_data = reinterpret_cast<const float *>(output_tensor.data().data());
  309. auto expect_data = reinterpret_cast<const float *>(data);
  310. for (size_t i = 0; i < data_len; i++) {
  311. EXPECT_EQ(real_data[i], expect_data[i]);
  312. if (real_data[i] != expect_data[i]) {
  313. break;
  314. }
  315. }
  316. break;
  317. }
  318. case proto::MS_INT32: {
  319. auto data_len = data_size / sizeof(int32_t);
  320. auto real_data = reinterpret_cast<const int32_t *>(output_tensor.data().data());
  321. auto expect_data = reinterpret_cast<const int32_t *>(data);
  322. for (size_t i = 0; i < data_len; i++) {
  323. EXPECT_EQ(real_data[i], expect_data[i]);
  324. if (real_data[i] != expect_data[i]) {
  325. break;
  326. }
  327. }
  328. break;
  329. }
  330. default:
  331. FAIL();
  332. }
  333. }
  334. static grpc::Status Dispatch(const proto::PredictRequest &request, proto::PredictReply *reply) {
  335. MSServiceImpl impl(Server::Instance().GetDispatcher());
  336. grpc::ServerContext context;
  337. auto promise = std::make_shared<std::promise<void>>();
  338. auto future = promise->get_future();
  339. DispatchCallback callback = [promise](Status status) { promise->set_value(); };
  340. auto status = impl.PredictAsync(&request, reply, callback);
  341. if (!status.IsSuccess()) {
  342. return grpc::Status::OK;
  343. }
  344. future.get();
  345. return grpc::Status::OK;
  346. }
  347. };
  348. } // namespace serving
  349. } // namespace mindspore
  350. #endif // MINDSPORE_SERVING_TEST_SERVABLE_COMMON_H

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.