You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

acl_session_test_add.cc 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "acl_session_test_common.h"
  17. using namespace std;
  18. namespace mindspore {
  19. namespace serving {
  20. class AclSessionAddTest : public AclSessionTest {
  21. public:
  22. AclSessionAddTest() = default;
  23. void SetUp() override {
  24. AclSessionTest::SetUp();
  25. aclmdlDesc model_desc;
  26. model_desc.inputs.push_back(
  27. AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
  28. model_desc.inputs.push_back(
  29. AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
  30. model_desc.outputs.push_back(
  31. AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
  32. mock_model_desc_ = MockModelDesc(model_desc);
  33. g_acl_model_desc = &mock_model_desc_;
  34. g_acl_model = &add_mock_model_;
  35. }
  36. void CreateDefaultRequest(PredictRequest &request) {
  37. auto input0 = request.add_data();
  38. CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
  39. auto input1 = request.add_data();
  40. CreateTensor(*input1, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
  41. auto input0_data = reinterpret_cast<float *>(input0->mutable_data()->data());
  42. auto input1_data = reinterpret_cast<float *>(input1->mutable_data()->data());
  43. for (int i = 0; i < 2 * 24 * 24 * 3; i++) {
  44. input0_data[i] = i % 1024;
  45. input1_data[i] = i % 1024 + 1;
  46. }
  47. }
  48. void CheckDefaultReply(const PredictReply &reply) {
  49. EXPECT_TRUE(reply.result().size() == 1);
  50. if (reply.result().size() == 1) {
  51. CheckTensorItem(reply.result(0), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
  52. auto &output = reply.result(0).data();
  53. EXPECT_EQ(output.size(), 2 * 24 * 24 * 3 * sizeof(float));
  54. if (output.size() == 2 * 24 * 24 * 3 * sizeof(float)) {
  55. auto output_data = reinterpret_cast<const float *>(output.data());
  56. for (int i = 0; i < 2 * 24 * 24 * 3; i++) {
  57. EXPECT_EQ(output_data[i], (i % 1024) + (i % 1024 + 1));
  58. if (output_data[i] != (i % 1024) + (i % 1024 + 1)) {
  59. break;
  60. }
  61. }
  62. }
  63. }
  64. }
  65. MockModelDesc mock_model_desc_;
  66. AddMockAclModel add_mock_model_;
  67. };
  68. TEST_F(AclSessionAddTest, TestAclSession_OneTime_Success) {
  69. inference::AclSession acl_session;
  70. uint32_t device_id = 1;
  71. EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
  72. uint32_t model_id = 0;
  73. EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
  74. // create inputs
  75. PredictRequest request;
  76. CreateDefaultRequest(request);
  77. PredictReply reply;
  78. ServingRequest serving_request(request);
  79. ServingReply serving_reply(reply);
  80. EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
  81. CheckDefaultReply(reply);
  82. EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
  83. EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
  84. };
  85. TEST_F(AclSessionAddTest, TestAclSession_MutilTimes_Success) {
  86. inference::AclSession acl_session;
  87. uint32_t device_id = 1;
  88. EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
  89. uint32_t model_id = 0;
  90. EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
  91. for (int i = 0; i < 10; i++) {
  92. // create inputs
  93. PredictRequest request;
  94. CreateDefaultRequest(request);
  95. PredictReply reply;
  96. ServingRequest serving_request(request);
  97. ServingReply serving_reply(reply);
  98. EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
  99. CheckDefaultReply(reply);
  100. }
  101. EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
  102. EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
  103. };
  104. TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_OneTime_Success) {
  105. SetDeviceRunMode();
  106. inference::AclSession acl_session;
  107. uint32_t device_id = 1;
  108. EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
  109. uint32_t model_id = 0;
  110. EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
  111. // create inputs
  112. PredictRequest request;
  113. CreateDefaultRequest(request);
  114. PredictReply reply;
  115. ServingRequest serving_request(request);
  116. ServingReply serving_reply(reply);
  117. EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
  118. CheckDefaultReply(reply);
  119. EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
  120. EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
  121. };
  122. TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_MutilTimes_Success) {
  123. SetDeviceRunMode();
  124. inference::AclSession acl_session;
  125. uint32_t device_id = 1;
  126. EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
  127. uint32_t model_id = 0;
  128. EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
  129. for (int i = 0; i < 10; i++) {
  130. // create inputs
  131. PredictRequest request;
  132. CreateDefaultRequest(request);
  133. PredictReply reply;
  134. ServingRequest serving_request(request);
  135. ServingReply serving_reply(reply);
  136. EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
  137. CheckDefaultReply(reply);
  138. }
  139. EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
  140. EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
  141. };
  142. } // namespace serving
  143. } // namespace mindspore