You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

resource.cc 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/resource.h"
  19. #include "pipeline/jit/static_analysis/static_analysis.h"
  20. #include "debug/trace.h"
  21. #include "ir/dtype.h"
  22. #include "pipeline/jit/parse/data_converter.h"
  23. #include "frontend/operator/ops.h"
  24. #include "frontend/optimizer/ad/dfunctor.h"
  25. namespace mindspore {
  26. // namespace to support opmap definition
  27. namespace pipeline {
  28. BuiltInTypeMap &GetMethodMap() {
  29. static BuiltInTypeMap method_map = {{kObjectTypeString,
  30. {
  31. {"__bool__", std::string("str_bool")} // C.str_bool
  32. }},
  33. {kMetaTypeNone,
  34. {
  35. {"__bool__", std::string("none_bool")} // C.none_bool
  36. }},
  37. {kObjectTypeFunction,
  38. {
  39. {"__bool__", std::string("func_bool")} // C.str_bool
  40. }},
  41. {kNumberTypeBool,
  42. {
  43. {"__and__", prim::kPrimBoolAnd}, // P.bool_and
  44. {"__or__", prim::kPrimBoolOr}, // P.bool_or
  45. {"__eq__", prim::kPrimBoolEq}, // P.bool_eq
  46. {"__ne__", std::string("bool_ne")}, // C.bool_ne
  47. {"__bool__", prim::kPrimIdentity} // P.identity
  48. }},
  49. {kNumberTypeInt,
  50. {
  51. {"__add__", prim::kPrimScalarAdd}, // P.scalar_add
  52. {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
  53. {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
  54. {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
  55. {"__truediv__", std::string("int_truediv")}, // C.int_truediv
  56. {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
  57. {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
  58. {"__floor__", prim::kPrimIdentity}, // P.identity
  59. {"__trunc__", prim::kPrimIdentity}, // P.identity
  60. {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
  61. {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
  62. {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
  63. {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
  64. {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
  65. {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
  66. {"__le__", prim::kPrimScalarLe}, // P.scalar_le
  67. {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
  68. {"__bool__", std::string("int_bool")}, // C.int_bool
  69. {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
  70. }},
  71. {kNumberTypeUInt,
  72. {
  73. {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
  74. {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
  75. {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
  76. {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
  77. {"__truediv__", std::string("int_truediv")}, // C.int_truediv
  78. {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
  79. {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
  80. {"__floor__", prim::kPrimIdentity}, // P.identity,
  81. {"__trunc__", prim::kPrimIdentity}, // P.identity,
  82. {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
  83. {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
  84. {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
  85. {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
  86. {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
  87. {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
  88. {"__le__", prim::kPrimScalarLe}, // P.scalar_le,
  89. {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
  90. {"__bool__", std::string("int_bool")}, // C.int_bool
  91. {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
  92. }},
  93. {kNumberTypeFloat,
  94. {
  95. {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
  96. {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
  97. {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
  98. {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
  99. {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
  100. {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
  101. {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
  102. {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
  103. {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
  104. {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
  105. {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
  106. {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
  107. {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
  108. {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
  109. {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
  110. {"__le__", prim::kPrimScalarLe}, // P.scalar_le,
  111. {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
  112. {"__bool__", std::string("float_bool")}, // C.float_bool
  113. {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
  114. }},
  115. {kObjectTypeTuple,
  116. {
  117. {"__len__", prim::kPrimTupleLen}, // P.tuple_len,
  118. {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
  119. {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
  120. {"__ms_iter__", prim::kPrimIdentity}, // P.identity,
  121. {"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
  122. {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
  123. {"__bool__", std::string("tuple_bool")} // C.tuple_bool
  124. }},
  125. {kObjectTypeList,
  126. {
  127. {"__len__", prim::kPrimListLen}, // P.list_len,
  128. {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
  129. {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
  130. {"__ms_iter__", prim::kPrimIdentity}, // P.identity
  131. {"__ms_next__", std::string("list_next")}, // C.list_next
  132. {"append", std::string("list_append")}, // C.list_next
  133. {"__bool__", std::string("list_bool")}, // C.list_bool
  134. {"__ms_hasnext__", std::string("list_hasnext")},
  135. }},
  136. {kObjectTypeDictionary,
  137. {
  138. {"__len__", prim::kPrimDictLen}, // P.dict_len
  139. {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
  140. {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
  141. {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
  142. {"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
  143. {"__bool__", std::string("dict_bool")} // C.dict_bool
  144. }},
  145. {kObjectTypeTensorType,
  146. {
  147. {"all", std::string("all_")}, // C.reduce_all
  148. {"any", std::string("any_")}, // C.reduce_any
  149. {"__add__", std::string("add")}, // C.add
  150. {"__sub__", std::string("sub")}, // C.sub
  151. {"__mul__", std::string("mul")}, // C.mul
  152. {"abs", std::string("abs_")}, // C.abs_
  153. {"mean", std::string("mean")}, // C.mean
  154. {"__truediv__", std::string("truediv")}, // C.truediv
  155. {"__floordiv__", std::string("floordiv")}, // C.floordiv
  156. {"__mod__", std::string("mod")}, // C.mod
  157. {"__pow__", std::string("pow_")}, // C.pow
  158. {"__floor__", std::string("array_floor")}, // C.array_floor
  159. {"__trunc__", std::string("array_trunc")}, // C.array_trunc
  160. {"__pos__", std::string("array_uadd")}, // C.array_uadd
  161. {"__neg__", std::string("array_usub")}, // C.array_usub
  162. {"__eq__", std::string("eq")}, // C.eq
  163. {"__ne__", std::string("ne")}, // C.ne
  164. {"__lt__", std::string("lt")}, // C.lt
  165. {"__gt__", std::string("gt")}, // C.gt
  166. {"__le__", std::string("le")}, // C.le
  167. {"__ge__", std::string("ge")}, // C.ge
  168. {"expand_as", std::string("expand_tensor_as")}, // C.expand_as
  169. {"view", std::string("view")}, // C.view
  170. {"__len__", prim::kPrimArrayLen}, // P.array_len,
  171. {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
  172. {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
  173. {"__ms_iter__", std::string("array_iter")}, // C.array_iter
  174. {"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
  175. {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
  176. {"transpose", std::string("transpose")}, // P.transpose
  177. {"flatten", std::string("flatten")}, // P.reshape(,-1)
  178. {"reshape", std::string("reshape")}, // P.reshape()
  179. {"ravel", std::string("ravel")}, // P.reshape(,(-1,))
  180. {"swapaxes", std::string("swapaxes")}, // P.transpose()
  181. {"squeeze", std::string("squeeze")}, // P.squeeze()
  182. {"astype", std::string("astype")}, // P.cast()
  183. {"__bool__", std::string("tensor_bool")}, // C.tensor_bool
  184. }},
  185. {kObjectTypeRowTensorType,
  186. {
  187. {"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
  188. }},
  189. {kObjectTypeJTagged, {}},
  190. {kObjectTypeSymbolicKeyType, {}},
  191. {kObjectTypeEnvType, {}}};
  192. return method_map;
  193. }
// Returns the table of builtin attributes, keyed by type. It is used by Resource::GetAttrPtr and
// Resource::IsTypeInBuiltInMap.
// Outer key: a TypeId (lookups use static_cast<int64_t>(type_id)).
// Inner map: attribute name -> implementation. The implementation is either a std::string naming a
// function resolved by name elsewhere (tensor attributes) or a primitive (prim::kPrim*, row/sparse tensors).
// The table is a function-local static, built on first call and returned by non-const reference.
BuiltInTypeMap &GetAttrMap() {
  static BuiltInTypeMap attr_map = {
    {kObjectTypeTensorType,
     {
       {"shape", std::string("shape_")},        // C.shape_
       {"dtype", std::string("dtype_")},        // C.dtype_
       {"size", std::string("size_")},          // C.size_
       {"ndim", std::string("ndim_")},          // C.ndim_
       {"T", std::string("T_")},                // C.T_
       {"itemsize", std::string("itemsize_")},  // C.itemsize_
       {"nbytes", std::string("nbytes_")},      // C.nbytes_
       {"strides", std::string("strides_")},    // C.strides_
     }},
    {kObjectTypeRowTensorType,
     {
       {"values", prim::kPrimRowTensorGetValues},            // F.row_tensor_get_values
       {"indices", prim::kPrimRowTensorGetIndices},          // F.row_tensor_get_indices
       {"dense_shape", prim::kPrimRowTensorGetDenseShape},  // F.row_tensor_get_dense_shape
     }},
    {kObjectTypeSparseTensorType,
     {
       {"values", prim::kPrimSparseTensorGetValues},            // F.sparse_tensor_get_values
       {"indices", prim::kPrimSparseTensorGetIndices},          // F.sparse_tensor_get_indices
       {"dense_shape", prim::kPrimSparseTensorGetDenseShape},  // F.sparse_tensor_get_dense_shape
     }},
  };
  return attr_map;
}
  222. Resource::Resource(const py::object &obj)
  223. : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
  224. input_(obj),
  225. is_cleaned_(false) {}
  226. Resource::~Resource() {
  227. MS_LOG(DEBUG) << "Resource clear";
  228. std::unordered_map<std::string, Any>().swap(results_);
  229. // If exit normally, these global variables will be cleaned
  230. // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
  231. // these global variables may not being cleaned, it may
  232. // cause segmentfault when free python object inside these global variables
  233. // after python interpreter got freed, so these global variables
  234. // are cleaned here.
  235. // So if exit normally, these global variable will be cleaned twice,
  236. // care be taken to prevent double free in the following functions.
  237. if (!is_cleaned_) {
  238. try {
  239. Clean();
  240. } catch (const std::exception &e) {
  241. MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
  242. } catch (...) {
  243. MS_LOG(ERROR) << "Exception when cleaning resource.";
  244. }
  245. }
  246. }
  247. Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
  248. auto type_method_map = method_map.find(static_cast<int64_t>(type_id));
  249. if (type_method_map == method_map.end()) {
  250. return Any();
  251. }
  252. auto method = type_method_map->second.find(name);
  253. if (method == type_method_map->second.end()) {
  254. return Any();
  255. }
  256. return method->second;
  257. }
  258. bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
  259. TypeId type_id = NormalizeTypeId(type);
  260. const BuiltInTypeMap &method_map = GetMethodMap();
  261. auto iter = method_map.find(static_cast<int64_t>(type_id));
  262. if (iter == method_map.end()) {
  263. const BuiltInTypeMap &attr_map = GetAttrMap();
  264. iter = attr_map.find(static_cast<int64_t>(type_id));
  265. if (iter == attr_map.end()) {
  266. return false;
  267. }
  268. }
  269. return true;
  270. }
  271. Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
  272. TypeId type_id = NormalizeTypeId(type);
  273. const BuiltInTypeMap &method_map = GetMethodMap();
  274. return GetMethodOrAttr(name, type_id, method_map);
  275. }
  276. Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
  277. TypeId type_id = NormalizeTypeId(type);
  278. const BuiltInTypeMap &attr_map = GetAttrMap();
  279. return GetMethodOrAttr(name, type_id, attr_map);
  280. }
// Releases what this Resource holds that may reference Python objects, then clears several process-wide
// caches. The goal is that nothing still refers to Python objects when the Python interpreter shuts down.
// It is invoked from ~Resource() when it has not already run, and sets is_cleaned_ on success.
// The order of the statements below is deliberate: keep it.
// If engine_ is null, MS_EXCEPTION_IF_NULL throws before the global caches are cleared and is_cleaned_
// stays false.
void Resource::Clean() {
  // AbstractTensor->elements() will be saved in AbstractBasePtrList
  args_spec_.clear();
  input_ = py::none();
  // Contexts holding an AbstractBasePtrList may be saved in a GraphEvaluator, and some evaluators
  // (e.g. ResolveEvaluator) may cache Python objects. These caches must be cleared before the Python
  // interpreter is destroyed.
  MS_EXCEPTION_IF_NULL(engine_);
  engine_->ClearEvaluatorCache();
  // Clear static caches to prevent a crash: static variables are released only after the Python threads
  // have been released.
  parse::data_converter::ClearObjectCache();
  parse::Parser::CleanParserResource();
  parse::CleanDataClassToClassMap();
  trace::ClearTraceStack();
  is_cleaned_ = true;
}
  298. void MemoryCleaner::Init() {
  299. pynative_in_construct_process_ = false;
  300. pynative_in_end_graph_process_ = false;
  301. pynative_released_history_.clear();
  302. pynative_new_primtives_squence_.clear();
  303. }
  304. MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
  305. void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
  306. if (prim == nullptr) {
  307. return;
  308. }
  309. all_primitives_[prim] = true;
  310. }
  311. void MemoryCleaner::ReleasePrimitivePyObj(PrimitivePy *prim) {
  312. if (prim == nullptr) {
  313. return;
  314. }
  315. auto it = all_primitives_.find(prim);
  316. if (it == all_primitives_.end()) {
  317. return;
  318. }
  319. // If flag is false,the pointer hased been released, so it can't be visited.
  320. if (!it->second) {
  321. return;
  322. }
  323. all_primitives_[prim] = false;
  324. prim->SetPyObj(py::none());
  325. }
  326. void MemoryCleaner::ClearPrimitivePyPythonObj() {
  327. for (auto &it : all_primitives_) {
  328. if (it.second) {
  329. it.first->SetPyObj(py::none());
  330. }
  331. }
  332. all_primitives_.clear();
  333. }
  334. void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
  335. if (prim == nullptr) {
  336. return;
  337. }
  338. if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
  339. return;
  340. }
  341. MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
  342. pynative_short_life_primitives_.insert(prim);
  343. pynative_new_primtives_squence_.push_back(prim->ToString());
  344. }
  345. void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
  346. if (prim == nullptr) {
  347. return;
  348. }
  349. if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
  350. return;
  351. }
  352. pynative_short_life_primitives_.erase(prim);
  353. MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
  354. }
  355. void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
  356. // If the primitives name sequence never been released before, keep the primtives alive
  357. if (std::find(pynative_released_history_.begin(), pynative_released_history_.end(),
  358. pynative_new_primtives_squence_) == pynative_released_history_.end()) {
  359. pynative_released_history_.push_back(pynative_new_primtives_squence_);
  360. } else {
  361. for (auto &primitive : pynative_short_life_primitives_) {
  362. ReleasePrimitivePyObj(primitive);
  363. }
  364. }
  365. pynative_short_life_primitives_.clear();
  366. pynative_new_primtives_squence_.clear();
  367. }
  368. void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; }
  369. void MemoryCleaner::LeavePynativeConstructProcess() {
  370. pynative_in_construct_process_ = false;
  371. ClearPynativeShortLifePrimitivePy();
  372. }
  373. bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; }
  374. void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; }
  375. void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; }
  376. bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; }
  377. } // namespace pipeline
  378. } // namespace mindspore