You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pynative_execute.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "pynative/pynative_execute.h"
  17. #include <typeinfo>
  18. #include <map>
  19. #include <set>
  20. #include <unordered_set>
  21. #include <algorithm>
  22. #include "utils/any.h"
  23. #include "utils/utils.h"
  24. #include "utils/context/ms_context.h"
  25. #include "operator/ops.h"
  26. #include "operator/composite/do_signature.h"
  27. #include "pipeline/parse/data_converter.h"
  28. #include "pipeline/static_analysis/prim.h"
  29. #include "session/session_factory.h"
  30. #include "pre_activate/pass/const_input_to_attr_registry.h"
  31. #include "pre_activate/common/helper.h"
  32. #include "pynative/base.h"
  33. #ifdef ENABLE_GE
  34. #include "pynative/pynative_execute_ge.h"
  35. #endif
  36. const char SINGLE_OP_GRAPH[] = "single_op_graph";
  37. // primitive unable to infer value for constant input in PyNative mode
  38. const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor"};
  39. namespace mindspore {
  40. namespace pynative {
  41. static std::shared_ptr<session::SessionBasic> session = nullptr;
  42. inline ValuePtr PyAttrValue(const py::object &obj) {
  43. ValuePtr converted_ret = nullptr;
  44. bool converted = parse::ConvertData(obj, &converted_ret);
  45. if (!converted) {
  46. MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj));
  47. }
  48. return converted_ret;
  49. }
// Implicit type promotion for mixed tensor/scalar inputs: when the primitive's
// signature groups several arguments under one dtype tag, scalar (int/float)
// arguments in that group are wrapped as tensors using the dtype of a tensor
// argument from the same group, so all group members share one dtype.
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) {
  auto signature = prim->signatures();
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  // nothing to promote when no dtype tags are declared (or all are "empty")
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
    return py_args;
  }
  // group argument indices by their dtype tag
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    auto it = type_indexs.find(dtypes[i]);
    if (it == type_indexs.end()) {
      (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
    } else {
      it->second.push_back(i);
    }
  }
  // for each group with at least two members, pick the index that supplies the
  // target dtype: the last tensor argument in the group, or the first member
  // if the group holds no tensor at all
  std::map<SignatureEnumDType, size_t> dst_type;
  for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
    auto type = it->first;
    auto indexs = it->second;
    if (indexs.size() < 2) {
      continue;
    }
    size_t m_index = indexs[0];
    for (size_t i = 1; i < indexs.size(); ++i) {
      if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
        m_index = indexs[i];
      }
    }
    (void)dst_type.insert(std::make_pair(type, m_index));
  }
  // rebuild the argument tuple, wrapping scalars as tensors of the target dtype
  py::tuple py_inputs(py_args.size());
  for (size_t i = 0; i < py_args.size(); ++i) {
    auto it = dst_type.find(dtypes[i]);
    if (it != dst_type.end() && it->second != i &&
        (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
      // NOTE(review): assumes the chosen dst index holds a tensor; if a promoted
      // group contained only scalars the cast below would fail — confirm upstream
      // signatures guarantee at least one tensor per promoted group.
      auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
      if (py::isinstance<py::int_>(py_args[i])) {
        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
      } else {
        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
      }
      continue;
    }
    py_inputs[i] = py_args[i];
  }
  return py_inputs;
}
  100. void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) {
  101. size_t size = py_args.size();
  102. AbstractBasePtrList args_spec_list;
  103. for (size_t i = 0; i < size; i++) {
  104. ValuePtr input_value = PyAttrValue(py_args[i]);
  105. if (py::isinstance<tensor::Tensor>(py_args[i])) {
  106. args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
  107. } else {
  108. args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
  109. }
  110. }
  111. AbstractBasePtr infer_res = InferOnePrim(prim, args_spec_list);
  112. op_exec_info->abstract = infer_res;
  113. }
// Build an OpExecInfo from the RunOp python args
// (name, primitive, inputs, input mask); returns nullptr on malformed args.
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Four args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
  MS_EXCEPTION_IF_NULL(op_exec_info);
  op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto pyobj = prim->GetPyObj();
  if (pyobj == nullptr) {
    MS_LOG(EXCEPTION) << "pyobj is empty";
  }
  // apply implicit dtype promotion to the raw python inputs
  py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, py_args, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->op_inputs = py_args;
  op_exec_info->inputs_mask = args[PY_INPUT_MASK];
  // every input must carry exactly one mask entry
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
    return nullptr;
  }
  return op_exec_info;
}
  142. std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) {
  143. MS_EXCEPTION_IF_NULL(op_exec_info);
  144. std::string graph_info;
  145. MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
  146. // get input tensor info
  147. size_t input_num = op_exec_info->op_inputs.size();
  148. for (size_t index = 0; index < input_num; ++index) {
  149. if (py::isinstance<tensor::Tensor>(op_exec_info->op_inputs[index])) {
  150. auto tensor_ptr = py::cast<tensor::TensorPtr>(op_exec_info->op_inputs[index]);
  151. MS_EXCEPTION_IF_NULL(tensor_ptr);
  152. (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_");
  153. }
  154. }
  155. // get prim and abstract info
  156. (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" +
  157. op_exec_info->abstract->ToString());
  158. MS_LOG(INFO) << "Graph info [" << graph_info << "]";
  159. return graph_info;
  160. }
  161. py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  162. MS_LOG(INFO) << "RunOpInVM start";
  163. MS_EXCEPTION_IF_NULL(status);
  164. MS_EXCEPTION_IF_NULL(op_exec_info);
  165. MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
  166. auto func = op_exec_info->py_primitive->GetComputeFunction();
  167. if (py::isinstance<py::none>(func)) {
  168. MS_LOG(ERROR) << "VM failed to get func";
  169. *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
  170. py::tuple err_ret(0);
  171. return std::move(err_ret);
  172. }
  173. // execute op
  174. py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
  175. *status = PYNATIVE_SUCCESS;
  176. MS_LOG(INFO) << "RunOpInVM end";
  177. return std::move(result);
  178. }
  179. bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
  180. const std::unordered_set<size_t> &input_attrs) {
  181. MS_EXCEPTION_IF_NULL(op_prim);
  182. auto input_names_value = op_prim->GetAttr(kAttrInputNames);
  183. if (input_names_value == nullptr) {
  184. return false;
  185. }
  186. auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
  187. if (input_index >= input_names_vec.size()) {
  188. MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
  189. }
  190. if (input_attrs.find(input_index) != input_attrs.end()) {
  191. ValuePtr value = parse::data_converter::PyDataToValue(input_object);
  192. MS_EXCEPTION_IF_NULL(value);
  193. auto input_name = input_names_vec[input_index];
  194. op_prim->set_attr(input_name, value);
  195. return true;
  196. }
  197. return false;
  198. }
  199. void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
  200. std::vector<tensor::TensorPtr> *input_tensor) {
  201. MS_EXCEPTION_IF_NULL(op_prim);
  202. MS_EXCEPTION_IF_NULL(input_tensor);
  203. for (const auto &input_object : tuple_inputs) {
  204. if (!py::isinstance<tensor::Tensor>(input_object)) {
  205. MS_LOG(EXCEPTION) << "The input object is not a tensor!";
  206. }
  207. auto tensor = py::cast<tensor::TensorPtr>(input_object);
  208. MS_EXCEPTION_IF_NULL(tensor);
  209. input_tensor->push_back(tensor);
  210. }
  211. op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
  212. }
  213. void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) {
  214. MS_EXCEPTION_IF_NULL(input_tensor);
  215. ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  216. MS_EXCEPTION_IF_NULL(input_value);
  217. if (!input_value->isa<ValueTuple>()) {
  218. MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  219. }
  220. auto value_tuple = input_value->cast<ValueTuplePtr>();
  221. MS_EXCEPTION_IF_NULL(value_tuple);
  222. tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  223. MS_EXCEPTION_IF_NULL(tensor_ptr);
  224. input_tensor->push_back(tensor_ptr);
  225. }
  226. void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
  227. std::vector<tensor::TensorPtr> *input_tensor) {
  228. MS_EXCEPTION_IF_NULL(op_prim);
  229. MS_EXCEPTION_IF_NULL(input_tensor);
  230. tensor::TensorPtr tensor_ptr = nullptr;
  231. if (py::isinstance<tensor::Tensor>(input_object)) {
  232. tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
  233. } else if (py::isinstance<py::float_>(input_object)) {
  234. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
  235. } else if (py::isinstance<py::int_>(input_object)) {
  236. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr);
  237. } else if (py::isinstance<py::list>(input_object)) {
  238. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
  239. } else if (py::isinstance<py::array>(input_object)) {
  240. tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
  241. } else if (py::isinstance<py::tuple>(input_object)) {
  242. auto tuple_inputs = py::cast<py::tuple>(input_object);
  243. if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
  244. PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor);
  245. } else {
  246. ConvertValueTupleToTensor(input_object, input_tensor);
  247. }
  248. return;
  249. } else {
  250. MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
  251. }
  252. MS_EXCEPTION_IF_NULL(tensor_ptr);
  253. input_tensor->push_back(tensor_ptr);
  254. }
// Lower the python-level op inputs into the tensor list / mask the session
// expects. Registered const inputs are folded into primitive attributes
// instead of tensors; a tuple input may expand into several tensors, so its
// mask entry is replicated once per tensor produced.
void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<bool> *tensors_mask,
                          std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(tensors_mask);
  MS_EXCEPTION_IF_NULL(input_tensors);
  PrimitivePtr op_prim = op_run_info->py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info->inputs_mask.size();
  }
  opt::ConstInputToAttrInfoRegister reg;
  bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
  size_t input_num = op_run_info->op_inputs.size();
  MS_LOG(INFO) << "py input size: " << input_num;
  for (size_t index = 0; index < input_num; ++index) {
    // convert const input to attr
    if (reg_exist &&
        RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
      continue;
    }
    // convert const and tuple input to tensor
    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors);
    // make tensors, weight : 1, data : 0
    // (one mask entry per tensor appended by the conversion above)
    std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(),
                               py::cast<bool>(op_run_info->inputs_mask[index]));
    tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
  }
}
// Run a single op through the MindSpore session backend (Ascend/GPU):
// build (or reuse) the cached session, construct the single-op graph keyed
// by GetSingleOpGraphInfo, then build and run it.
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
  MS_EXCEPTION_IF_NULL(op_exec_info);
  MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // enable pynative infer for the duration of this call
  // NOTE(review): the flag is not restored if an exception escapes below — confirm intended
  ms_context->set_enable_pynative_infer(true);
  std::string device_target = ms_context->device_target();
  if (device_target != kAscendDevice && device_target != kGPUDevice) {
    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
  }
  // lazily create the session; it is cached until ClearPyNativeSession()
  if (session == nullptr) {
    session = session::SessionFactory::Get().Create(device_target);
  }
  MS_EXCEPTION_IF_NULL(session);
  // Init is invoked on every op, not only on creation
  session->Init(ms_context->device_id());
  std::string graph_info = GetSingleOpGraphInfo(op_exec_info);
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<bool> tensors_mask;
  ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
  session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
  py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
  ms_context->set_enable_pynative_infer(false);
  *status = PYNATIVE_SUCCESS;
  return result;
}
// Dispatch op execution to the backend selected by `backend_policy`:
// VM only, GE with VM fallback, or MS with an error report on failure.
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info,
                                  PynativeStatusCode *const status) {
  MS_EXCEPTION_IF_NULL(status);
  py::object result;
  switch (backend_policy) {
    case kMsBackendVmOnly: {
      // use vm only
      MS_LOG(INFO) << "RunOp use VM only backend";
      result = RunOpInVM(op_exec_info, status);
      break;
    }
    case kMsBackendGePrior: {
#ifdef ENABLE_GE
      // use GE first, use vm when GE fails
      MS_LOG(INFO) << "RunOp use GE first backend";
      result = RunOpInGE(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        result = RunOpInVM(op_exec_info, status);
      }
#endif
      break;
    }
    case kMsBackendMsPrior: {
      // use Ms first, use others when ms failed
      MS_LOG(INFO) << "RunOp use Ms first backend";
      result = RunOpInMs(op_exec_info, status);
      if (*status != PYNATIVE_SUCCESS) {
        MS_LOG(ERROR) << "RunOp use Ms backend failed!!!";
      }
      break;
    }
    default:
      MS_LOG(ERROR) << "No backend configured for run op";
  }
  return result;
}
// Python entry point for PyNative op execution: generate the OpExecInfo,
// short-circuit when inference already produced a constant value, pick a
// backend policy from the context, and run the op through it.
py::tuple RunOp(const py::args &args) {
  py::object result;
  // returns a null py::tuple on error
  py::tuple err_ret(0);
  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
  MS_EXCEPTION_IF_NULL(op_exec_info);
  if (op_exec_info->abstract != nullptr) {
    // inference already folded the op to a constant — return it directly
    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
    if (!output["value"].is_none()) {
      py::tuple value_ret(1);
      value_ret[0] = output["value"];
      return value_ret;
    }
  }
  MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
  mindspore::parse::python_adapter::set_python_env_flag(true);
  MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ms") {
    backend_policy = kMsBackendMsPrior;
  } else {
    backend_policy = kMsBackendVmOnly;
  }
#else
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->PynativeInitGe();
  backend_policy = kMsBackendGeOnly;
#endif
  // ops whose constant inputs the device backends cannot infer must use VM
  if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
    backend_policy = kMsBackendVmOnly;
  }
  result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
  if (status != PYNATIVE_SUCCESS) {
    MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
    return err_ret;
  }
  MS_LOG(INFO) << "RunOp end";
  return result;
}
  387. void ClearPyNativeSession() { session = nullptr; }
  388. } // namespace pynative
  389. } // namespace mindspore