You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

client_example.py 3.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import random
  16. import json
  17. import grpc
  18. import numpy as np
  19. import requests
  20. import ms_service_pb2
  21. import ms_service_pb2_grpc
  22. import mindspore.dataset as de
  23. from mindspore import Tensor, context
  24. from mindspore import log as logger
  25. from tests.st.networks.models.bert.src.bert_model import BertModel
  26. from .generate_model import bert_net_cfg
  27. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  28. random.seed(1)
  29. np.random.seed(1)
  30. de.config.set_seed(1)
  31. def test_bert():
  32. MAX_MESSAGE_LENGTH = 0x7fffffff
  33. input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
  34. segment_ids = np.zeros((2, 32), dtype=np.int32)
  35. input_mask = np.zeros((2, 32), dtype=np.int32)
  36. # grpc visit
  37. channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
  38. ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
  39. stub = ms_service_pb2_grpc.MSServiceStub(channel)
  40. request = ms_service_pb2.PredictRequest()
  41. x = request.data.add()
  42. x.tensor_shape.dims.extend([2, 32])
  43. x.tensor_type = ms_service_pb2.MS_INT32
  44. x.data = input_ids.tobytes()
  45. y = request.data.add()
  46. y.tensor_shape.dims.extend([2, 32])
  47. y.tensor_type = ms_service_pb2.MS_INT32
  48. y.data = segment_ids.tobytes()
  49. z = request.data.add()
  50. z.tensor_shape.dims.extend([2, 32])
  51. z.tensor_type = ms_service_pb2.MS_INT32
  52. z.data = input_mask.tobytes()
  53. result = stub.Predict(request)
  54. grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
  55. print("ms grpc client received: ")
  56. print(grpc_result)
  57. # ms result
  58. net = BertModel(bert_net_cfg, False)
  59. bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
  60. print("bert out: ")
  61. print(bert_out[0])
  62. bert_out_size = len(bert_out)
  63. # compare grpc result
  64. for i in range(bert_out_size):
  65. grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
  66. logger.info("i:{}, grpc_result:{}, bert_out:{}".
  67. format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
  68. assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True)
  69. # http visit
  70. data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]}
  71. url = "http://127.0.0.1:5501"
  72. input_json = json.dumps(data)
  73. headers = {'Content-type': 'application/json'}
  74. response = requests.post(url, data=input_json, headers=headers)
  75. result = response.text
  76. result = result.replace('\r', '\\r').replace('\n', '\\n')
  77. result_json = json.loads(result, strict=False)
  78. http_result = np.array(result_json['tensor'])
  79. print("ms http client received: ")
  80. print(http_result[0][:200])
  81. # compare http result
  82. for i in range(bert_out_size):
  83. logger.info("i:{}, http_result:{}, bert_out:{}".
  84. format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape))
  85. assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True)