You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serving_py.cc 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include "python/worker/preprocess_py.h"
  18. #include "python/worker/postprocess_py.h"
  19. #include "python/worker/worker_py.h"
  20. #include "python/tensor_py.h"
  21. #include "common/servable.h"
  22. #include "worker/context.h"
  23. #include "python/master/master_py.h"
  24. #include "master/dispacther.h"
  25. namespace mindspore::serving {
  26. PYBIND11_MODULE(_mindspore_serving, m) {
  27. py::class_<PyPreprocessStorage, std::shared_ptr<PyPreprocessStorage>>(m, "PreprocessStorage_")
  28. .def(py::init<>())
  29. .def_static("get_instance", &PyPreprocessStorage::Instance)
  30. .def("register", &PyPreprocessStorage::Register)
  31. .def("get_pycpp_preprocess_info", &PyPreprocessStorage::GetPyCppPreprocessInfo);
  32. py::class_<PyPostprocessStorage, std::shared_ptr<PyPostprocessStorage>>(m, "PostprocessStorage_")
  33. .def(py::init<>())
  34. .def_static("get_instance", &PyPostprocessStorage::Instance)
  35. .def("register", &PyPostprocessStorage::Register)
  36. .def("get_pycpp_postprocess_info", &PyPostprocessStorage::GetPyCppPostprocessInfo);
  37. py::enum_<PredictPhaseTag>(m, "PredictPhaseTag_")
  38. .value("kPredictPhaseTag_Input", PredictPhaseTag::kPredictPhaseTag_Input)
  39. .value("kPredictPhaseTag_Preproces", PredictPhaseTag::kPredictPhaseTag_Preproces)
  40. .value("kPredictPhaseTag_Predict", PredictPhaseTag::kPredictPhaseTag_Predict)
  41. .value("kPredictPhaseTag_Postprocess", PredictPhaseTag::kPredictPhaseTag_Postprocess)
  42. .export_values();
  43. py::class_<MethodSignature>(m, "MethodSignature_")
  44. .def(py::init<>())
  45. .def_readwrite("method_name", &MethodSignature::method_name)
  46. .def_readwrite("inputs", &MethodSignature::inputs)
  47. .def_readwrite("outputs", &MethodSignature::outputs)
  48. .def_readwrite("preprocess_name", &MethodSignature::preprocess_name)
  49. .def_readwrite("preprocess_inputs", &MethodSignature::preprocess_inputs)
  50. .def_readwrite("postprocess_name", &MethodSignature::postprocess_name)
  51. .def_readwrite("postprocess_inputs", &MethodSignature::postprocess_inputs)
  52. .def_readwrite("servable_name", &MethodSignature::servable_name)
  53. .def_readwrite("servable_inputs", &MethodSignature::servable_inputs)
  54. .def_readwrite("returns", &MethodSignature::returns);
  55. py::class_<RequestSpec>(m, "RequestSpec_")
  56. .def(py::init<>())
  57. .def_readwrite("servable_name", &RequestSpec::servable_name)
  58. .def_readwrite("version_number", &RequestSpec::version_number)
  59. .def_readwrite("method_name", &RequestSpec::method_name);
  60. py::class_<ServableMeta>(m, "ServableMeta_")
  61. .def(py::init<>())
  62. .def_readwrite("servable_name", &ServableMeta::servable_name)
  63. .def_readwrite("inputs_count", &ServableMeta::inputs_count)
  64. .def_readwrite("outputs_count", &ServableMeta::outputs_count)
  65. .def_readwrite("servable_file", &ServableMeta::servable_file)
  66. .def_readwrite("with_batch_dim", &ServableMeta::with_batch_dim)
  67. .def_readwrite("options", &ServableMeta::load_options)
  68. .def_readwrite("without_batch_dim_inputs", &ServableMeta::without_batch_dim_inputs)
  69. .def("set_model_format", &ServableMeta::SetModelFormat);
  70. py::class_<ServableSignature>(m, "ServableSignature_")
  71. .def(py::init<>())
  72. .def_readwrite("servable_meta", &ServableSignature::servable_meta)
  73. .def_readwrite("methods", &ServableSignature::methods);
  74. py::class_<ServableStorage, std::shared_ptr<ServableStorage>>(m, "ServableStorage_")
  75. .def(py::init<>())
  76. .def_static("get_instance", &ServableStorage::Instance)
  77. .def("register_servable", &ServableStorage::Register)
  78. .def("register_servable_input_output_info", &ServableStorage::RegisterInputOutputInfo)
  79. .def("get_servable_input_output_info", &ServableStorage::GetInputOutputInfo)
  80. .def("register_method", &ServableStorage::RegisterMethod)
  81. .def("declare_servable", &ServableStorage::DeclareServable);
  82. py::class_<TaskContext>(m, "TaskContext_").def(py::init<>());
  83. py::class_<TaskItem>(m, "TaskItem_")
  84. .def(py::init<>())
  85. .def_readwrite("task_type", &TaskItem::task_type)
  86. .def_readwrite("name", &TaskItem::name)
  87. .def_property_readonly("instance_list",
  88. [](const TaskItem &item) {
  89. py::tuple instances(item.instance_list.size());
  90. for (size_t i = 0; i < item.instance_list.size(); i++) {
  91. instances[i] = PyTensor::AsNumpyTuple(item.instance_list[i].data);
  92. }
  93. return instances;
  94. })
  95. .def_readwrite("context_list", &TaskItem::context_list);
  96. py::class_<PyWorker>(m, "Worker_")
  97. .def_static("start_servable", &PyWorker::StartServable)
  98. .def_static("start_servable_in_master", &PyWorker::StartServableInMaster)
  99. .def_static("get_batch_size", &PyWorker::GetBatchSize)
  100. .def_static("wait_and_clear", &PyWorker::WaitAndClear)
  101. .def_static("stop", PyWorker::Stop)
  102. .def_static("get_py_task", &PyWorker::GetPyTask, py::call_guard<py::gil_scoped_release>())
  103. .def_static("try_get_preprocess_py_task", &PyWorker::TryGetPreprocessPyTask)
  104. .def_static("try_get_postprocess_py_task", &PyWorker::TryGetPostprocessPyTask)
  105. .def_static("push_preprocess_result", &PyWorker::PushPreprocessPyResult)
  106. .def_static("push_preprocess_failed", &PyWorker::PushPreprocessPyFailed)
  107. .def_static("push_postprocess_result", &PyWorker::PushPostprocessPyResult)
  108. .def_static("push_postprocess_failed", &PyWorker::PushPostprocessPyFailed);
  109. py::class_<ServableContext, std::shared_ptr<ServableContext>>(m, "Context_")
  110. .def(py::init<>())
  111. .def_static("get_instance", &ServableContext::Instance)
  112. .def("set_device_type_str",
  113. [](ServableContext &context, const std::string &device_type) {
  114. auto status = context.SetDeviceTypeStr(device_type);
  115. if (status != SUCCESS) {
  116. MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  117. }
  118. })
  119. .def("set_device_id", &ServableContext::SetDeviceId);
  120. py::class_<PyMaster, std::shared_ptr<PyMaster>>(m, "Master_")
  121. .def_static("start_grpc_server", &PyMaster::StartGrpcServer)
  122. .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer)
  123. .def_static("start_restful_server", &PyMaster::StartRestfulServer)
  124. .def_static("wait_and_clear", &PyMaster::WaitAndClear)
  125. .def_static("stop", &PyMaster::Stop);
  126. (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
  127. Server::Instance().Clear();
  128. Worker::GetInstance().Clear();
  129. }});
  130. }
  131. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in production environments.