| @@ -0,0 +1,98 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| import grpc | |||
| import numpy as np | |||
| import ms_service_pb2 | |||
| import ms_service_pb2_grpc | |||
| import mindspore.dataset as de | |||
| from mindspore import Tensor, context | |||
| from mindspore import log as logger | |||
| from tests.st.networks.models.bert.src.bert_model import BertModel | |||
| from .generate_model import AddNet, bert_net_cfg | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| de.config.set_seed(1) | |||
| def test_add(): | |||
| channel = grpc.insecure_channel('localhost:5500') | |||
| stub = ms_service_pb2_grpc.MSServiceStub(channel) | |||
| request = ms_service_pb2.PredictRequest() | |||
| x = request.data.add() | |||
| x.tensor_shape.dims.extend([4]) | |||
| x.tensor_type = ms_service_pb2.MS_FLOAT32 | |||
| x.data = (np.ones([4]).astype(np.float32)).tobytes() | |||
| y = request.data.add() | |||
| y.tensor_shape.dims.extend([4]) | |||
| y.tensor_type = ms_service_pb2.MS_FLOAT32 | |||
| y.data = (np.ones([4]).astype(np.float32)).tobytes() | |||
| result = stub.Predict(request) | |||
| result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) | |||
| print("ms client received: ") | |||
| print(result_np) | |||
| net = AddNet() | |||
| net_out = net(Tensor(np.ones([4]).astype(np.float32)), Tensor(np.ones([4]).astype(np.float32))) | |||
| print("add net out: ") | |||
| print(net_out) | |||
| assert np.allclose(net_out.asnumpy(), result_np, 0.001, 0.001, equal_nan=True) | |||
| def test_bert(): | |||
| MAX_MESSAGE_LENGTH = 0x7fffffff | |||
| input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32) | |||
| segment_ids = np.zeros((2, 32), dtype=np.int32) | |||
| input_mask = np.zeros((2, 32), dtype=np.int32) | |||
| channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), | |||
| ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)]) | |||
| stub = ms_service_pb2_grpc.MSServiceStub(channel) | |||
| request = ms_service_pb2.PredictRequest() | |||
| x = request.data.add() | |||
| x.tensor_shape.dims.extend([2, 32]) | |||
| x.tensor_type = ms_service_pb2.MS_INT32 | |||
| x.data = input_ids.tobytes() | |||
| y = request.data.add() | |||
| y.tensor_shape.dims.extend([2, 32]) | |||
| y.tensor_type = ms_service_pb2.MS_INT32 | |||
| y.data = segment_ids.tobytes() | |||
| z = request.data.add() | |||
| z.tensor_shape.dims.extend([2, 32]) | |||
| z.tensor_type = ms_service_pb2.MS_INT32 | |||
| z.data = input_mask.tobytes() | |||
| result = stub.Predict(request) | |||
| result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims) | |||
| print("ms client received: ") | |||
| print(result_np) | |||
| net = BertModel(bert_net_cfg, False) | |||
| bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask)) | |||
| print("bert out: ") | |||
| print(bert_out) | |||
| bert_out_size = len(bert_out) | |||
| for i in range(bert_out_size): | |||
| result_np = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims) | |||
| logger.info("i:{}, result_np:{}, bert_out:{}". | |||
| format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape)) | |||
| assert np.allclose(bert_out[i].asnumpy(), result_np, 0.001, 0.001, equal_nan=True) | |||
| @@ -0,0 +1,76 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import random | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as de | |||
| from mindspore import Tensor, context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.serialization import export | |||
| from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig | |||
| bert_net_cfg = BertConfig( | |||
| batch_size=2, | |||
| seq_length=32, | |||
| vocab_size=21128, | |||
| hidden_size=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act="gelu", | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| use_relative_positions=False, | |||
| input_mask_from_dataset=True, | |||
| token_type_ids_from_dataset=True, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16 | |||
| ) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| de.config.set_seed(1) | |||
| class AddNet(nn.Cell): | |||
| def __init__(self): | |||
| super(AddNet, self).__init__() | |||
| self.add = P.TensorAdd() | |||
| def construct(self, x_, y_): | |||
| return self.add(x_, y_) | |||
| def export_add_model(): | |||
| net = AddNet() | |||
| x = np.ones(4).astype(np.float32) | |||
| y = np.ones(4).astype(np.float32) | |||
| export(net, Tensor(x), Tensor(y), file_name='add.pb', file_format='BINARY') | |||
| def export_bert_model(): | |||
| net = BertModel(bert_net_cfg, False) | |||
| input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32) | |||
| segment_ids = np.zeros((2, 32), dtype=np.int32) | |||
| input_mask = np.zeros((2, 32), dtype=np.int32) | |||
| export(net, Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask), file_name='bert.pb', file_format='BINARY') | |||
| if __name__ == '__main__': | |||
| export_add_model() | |||
| export_bert_model() | |||
| @@ -0,0 +1,117 @@ | |||
| #!/bin/bash | |||
| export GLOG_v=1 | |||
| export DEVICE_ID=1 | |||
| MINDSPORE_INSTALL_PATH=$1 | |||
| CURRPATH=$(cd $(dirname $0); pwd) | |||
| CURRUSER=$(whoami) | |||
| PROJECT_PATH=${CURRPATH}/../../../ | |||
| ENV_DEVICE_ID=$DEVICE_ID | |||
| echo "MINDSPORE_INSTALL_PATH:" ${MINDSPORE_INSTALL_PATH} | |||
| echo "CURRPATH:" ${CURRPATH} | |||
| echo "CURRUSER:" ${CURRUSER} | |||
| echo "PROJECT_PATH:" ${PROJECT_PATH} | |||
| echo "ENV_DEVICE_ID:" ${ENV_DEVICE_ID} | |||
| MODEL_PATH=${CURRPATH}/model | |||
| export LD_LIBRARY_PATH=${MINDSPORE_INSTALL_PATH}/lib:/usr/local/python/python375/lib/:${LD_LIBRARY_PATH} | |||
| export PYTHONPATH=${MINDSPORE_INSTALL_PATH}/../:${PYTHONPATH} | |||
| echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH} | |||
| echo "PYTHONPATH: " ${PYTHONPATH} | |||
| echo "-------------show MINDSPORE_INSTALL_PATH----------------" | |||
| ls -l ${MINDSPORE_INSTALL_PATH} | |||
| echo "------------------show /usr/lib64/----------------------" | |||
| ls -l /usr/local/python/python375/lib/ | |||
| clean_pid() | |||
| { | |||
| ps aux | grep 'ms_serving' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15 | |||
| if [ $? -ne 0 ] | |||
| then | |||
| echo "clean pip failed" | |||
| fi | |||
| sleep 6 | |||
| } | |||
| prepare_model() | |||
| { | |||
| echo "### begin to generate mode for serving test ###" | |||
| python3 generate_model.py &> generate_model_serving.log | |||
| echo "### end to generate mode for serving test ###" | |||
| result=`ls -l | grep -E '*pb' | grep -v ".log" | wc -l` | |||
| if [ ${result} -ne 2 ] | |||
| then | |||
| cat generate_model_serving.log | |||
| echo "### generate model for serving test failed ###" && exit 1 | |||
| clean_pid | |||
| fi | |||
| rm -rf model | |||
| mkdir model | |||
| mv *.pb ${CURRPATH}/model | |||
| cp ${MINDSPORE_INSTALL_PATH}/ms_serving ./ | |||
| } | |||
| start_service() | |||
| { | |||
| ${CURRPATH}/ms_serving --port=$1 --model_path=${MODEL_PATH} --model_name=$2 --device_id=$3 > $2_service.log 2>&1 & | |||
| if [ $? -ne 0 ] | |||
| then | |||
| echo "$2 faile to start." | |||
| fi | |||
| result=`grep -E 'MS Serving listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` | |||
| count=0 | |||
| while [[ ${result} -ne 1 && ${count} -lt 150 ]] | |||
| do | |||
| sleep 1 | |||
| count=$(($count+1)) | |||
| result=`grep -E 'MS Serving listening on 0.0.0.0:5500|MS Serving listening on 0.0.0.0:5501' $2_service.log | wc -l` | |||
| done | |||
| if [ ${count} -eq 150 ] | |||
| then | |||
| clean_pid | |||
| cat $2_service.log | |||
| echo "start serving service failed!" && exit 1 | |||
| fi | |||
| echo "### start serving service end ###" | |||
| } | |||
| pytest_serving() | |||
| { | |||
| unset http_proxy https_proxy | |||
| CLIENT_DEVICE_ID=$((${ENV_DEVICE_ID}+1)) | |||
| export DEVICE_ID=${CLIENT_DEVICE_ID} | |||
| local test_client_name=$1 | |||
| echo "### $1 client start ###" | |||
| python3 -m pytest -v -s client_example.py::${test_client_name} > ${test_client_name}_client.log 2>&1 | |||
| if [ $? -ne 0 ] | |||
| then | |||
| clean_pid | |||
| cat ${test_client_name}_client.log | |||
| echo "client $1 faile to start." | |||
| fi | |||
| echo "### $1 client end ###" | |||
| } | |||
| test_add_model() | |||
| { | |||
| start_service 5500 add.pb ${ENV_DEVICE_ID} | |||
| pytest_serving test_add | |||
| clean_pid | |||
| } | |||
| test_bert_model() | |||
| { | |||
| start_service 5500 bert.pb ${ENV_DEVICE_ID} | |||
| pytest_serving test_bert | |||
| clean_pid | |||
| } | |||
| echo "-----serving start-----" | |||
| rm -rf ms_serving *.log *.pb *.dat ${CURRPATH}/model ${CURRPATH}/kernel_meta | |||
| prepare_model | |||
| test_add_model | |||
| test_bert_model | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import sys | |||
| import pytest | |||
| import numpy as np | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.env_single | |||
| def test_serving(): | |||
| """test_serving""" | |||
| sh_path = os.path.split(os.path.realpath(__file__))[0] | |||
| python_path_folders = [] | |||
| for python_path in sys.path: | |||
| if os.path.isdir(python_path): | |||
| python_path_folders += [python_path] | |||
| folders = [] | |||
| for folder in python_path_folders: | |||
| folders += [os.path.join(folder, x) for x in os.listdir(folder) \ | |||
| if os.path.isdir(os.path.join(folder, x)) and '/site-packages/mindspore' in os.path.join(folder, x)] | |||
| ret = os.system(f"sh {sh_path}/serving.sh {folders[0].split('mindspore', 1)[0] + 'mindspore'}") | |||
| assert np.allclose(ret, 0, 0.0001, 0.0001) | |||
| if __name__ == '__main__': | |||
| test_serving() | |||