| @@ -19,6 +19,7 @@ | |||
| #include <cstring> | |||
| #include <thread> | |||
| #include <chrono> | |||
| #include <map> | |||
| #include "worker/preprocess.h" | |||
| #include "worker/postprocess.h" | |||
| #include "mindspore_serving/ccsrc/common/tensor.h" | |||
| @@ -265,13 +266,16 @@ void WorkExecutor::OnRecievePreprocessInputs(const std::vector<Instance> &inputs | |||
| } | |||
| } | |||
| void WorkExecutor::OnRecievePostprocessInputs(const Instance &input) { | |||
| const MethodSignature &method_def = input.context.user_context->method_def; | |||
| auto real_input = CreateInputInstance(input, kPredictPhaseTag_Postprocess); | |||
| void WorkExecutor::OnRecievePostprocessInputs(const std::vector<Instance> &inputs) { | |||
| if (inputs.empty()) { | |||
| MSI_LOG_EXCEPTION << "Inputs cannot be empty"; | |||
| } | |||
| const MethodSignature &method_def = inputs[0].context.user_context->method_def; | |||
| auto real_input = CreateInputInstance(inputs, kPredictPhaseTag_Postprocess); | |||
| if (python_postprocess_names_.count(method_def.postprocess_name) > 0) { | |||
| py_postprocess_task_queue_->PushTask(method_def.postprocess_name, GetWorkerId(), {real_input}); | |||
| py_postprocess_task_queue_->PushTask(method_def.postprocess_name, GetWorkerId(), real_input); | |||
| } else { | |||
| cpp_postprocess_task_queue_->PushTask(method_def.postprocess_name, GetWorkerId(), {real_input}); | |||
| cpp_postprocess_task_queue_->PushTask(method_def.postprocess_name, GetWorkerId(), real_input); | |||
| } | |||
| } | |||
| @@ -332,14 +336,22 @@ void WorkExecutor::PredictHandle(const std::vector<Instance> &inputs) { | |||
| this->ReplyError(inputs, status); | |||
| return; | |||
| } | |||
| std::map<std::string, std::vector<Instance>> map_output; | |||
| std::vector<Instance> reply_result; | |||
| for (auto &output : outputs) { | |||
| MethodSignature &method_def = output.context.user_context->method_def; | |||
| if (!method_def.postprocess_name.empty()) { | |||
| OnRecievePostprocessInputs(output); | |||
| map_output[method_def.postprocess_name].push_back(output); | |||
| } else { | |||
| ReplyRequest(output); | |||
| reply_result.push_back(output); | |||
| } | |||
| } | |||
| if (!reply_result.empty()) { | |||
| ReplyRequest(reply_result); | |||
| } | |||
| for (auto &item : map_output) { | |||
| OnRecievePostprocessInputs(item.second); | |||
| } | |||
| return; | |||
| } catch (const std::bad_alloc &ex) { | |||
| status = INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: malloc memory failed"; | |||
| @@ -76,9 +76,9 @@ class WorkExecutor { | |||
| bool ReplyRequest(const std::vector<Instance> &outputs); | |||
| bool ReplyRequest(const Instance &outputs); | |||
| void OnRecievePreprocessInputs(const std::vector<Instance> &inputs); // callback | |||
| void OnRecievePredictInputs(const std::vector<Instance> &inputs); // callback | |||
| void OnRecievePostprocessInputs(const Instance &inputs); // callback | |||
| void OnRecievePreprocessInputs(const std::vector<Instance> &inputs); // callback | |||
| void OnRecievePredictInputs(const std::vector<Instance> &inputs); // callback | |||
| void OnRecievePostprocessInputs(const std::vector<Instance> &inputs); // callback | |||
| void PredictHandle(const std::vector<Instance> &inputs); | |||
| Status PrePredict(const std::vector<Instance> &inputs); | |||
| @@ -42,9 +42,17 @@ export LD_LIBRARY_PATH=${BUILD_PATH}/mindspore_serving/tests/ut/python:${LD_LIBR | |||
| echo "PYTHONPATH=$PYTHONPATH" | |||
| echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" | |||
| unset http_proxy | |||
| unset https_proxy | |||
| PROCESS=`netstat -nlp | grep :5500 | awk '{print $7}' | awk -F"/" '{print $1}'` | |||
| for i in $PROCESS | |||
| do | |||
| echo "Kill the process [ $i ]" | |||
| kill -9 $i | |||
| done | |||
| cd - | |||
| cd ${PROJECT_PATH}/tests/ut/python/tests/ | |||
| if [ $# -gt 0 ]; then | |||
| @@ -428,7 +428,7 @@ def add_common(x1, x2): | |||
| result = client.infer(instances) | |||
| print(result) | |||
| assert len(result) == instance_count | |||
| assert "Postprocess Failed" in str(result[1]["error"]) | |||
| assert "Postprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) | |||
| @serving_test | |||
| @@ -510,8 +510,8 @@ def add_common(x1, x2): | |||
| print(result) | |||
| assert len(result) == instance_count | |||
| assert "Postprocess Failed" in str(result[0]["error"]) | |||
| assert "Postprocess Failed" in str(result[1]["error"]) | |||
| assert "Postprocess Failed" in str(result[0]["error"]) or "Servable stopped" in str(result[0]["error"]) | |||
| assert "Postprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) | |||
| @serving_test | |||
| @@ -596,9 +596,7 @@ def add_common(x1, x2): | |||
| print(result) | |||
| assert len(result) == instance_count | |||
| assert result[0]["y"] == 0 | |||
| assert result[1]["y"] == 1 | |||
| assert "Preprocess Failed" in str(result[2]["error"]) | |||
| assert "Preprocess Failed" in str(result[2]["error"]) or "Servable stopped" in str(result[2]["error"]) | |||
| @serving_test | |||
| @@ -685,9 +683,9 @@ def add_common(x1, x2): | |||
| print(result) | |||
| assert len(result) == instance_count | |||
| assert "Preprocess Failed" in str(result[0]["error"]) | |||
| assert "Preprocess Failed" in str(result[1]["error"]) | |||
| assert "Preprocess Failed" in str(result[2]["error"]) | |||
| assert "Preprocess Failed" in str(result[0]["error"]) or "Servable stopped" in str(result[0]["error"]) | |||
| assert "Preprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) | |||
| assert "Preprocess Failed" in str(result[2]["error"]) or "Servable stopped" in str(result[2]["error"]) | |||
| @serving_test | |||
| @@ -731,5 +729,5 @@ def add_common(x1, x2): | |||
| assert len(result) == instance_count | |||
| assert result[0]["y"] == 0 | |||
| assert "Preprocess Failed" in str(result[1]["error"]) | |||
| assert "Preprocess Failed" in str(result[1]["error"]) or "Servable stopped" in str(result[1]["error"]) | |||
| assert result[0]["y"] == 0 | |||
| @@ -1 +1 @@ | |||
| Subproject commit b5ad38fab8805ad3ac8486ab60d82ff71e8a2398 | |||
| Subproject commit 52fac12367131ec57e87ba757e42fc25479f433a | |||