You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gllo_utils.cc 50 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/common/gllo_utils.h"
  17. #include <vector>
  18. #include <utility>
  19. #include <algorithm>
  20. #include <unordered_map>
  21. #include <functional>
  22. #include <string>
  23. #include "src/ops/primitive_c.h"
  24. #include "src/common/common.h"
  25. #include "frontend/operator/ops.h"
  26. #include "backend/optimizer/common/helper.h"
  27. namespace mindspore {
  28. namespace opt {
  29. namespace {
  30. constexpr auto kAnfPrimitiveIndex = 0;
  31. bool IsRealKernel(const AnfNodePtr &node) {
  32. if (node == nullptr) {
  33. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  34. return false;
  35. }
  36. // parameter and value node is not a real kernel too
  37. if (!node->isa<CNode>()) {
  38. return true;
  39. }
  40. auto cnode = node->cast<CNodePtr>();
  41. if (cnode == nullptr) {
  42. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  43. return false;
  44. }
  45. if (cnode->inputs().empty()) {
  46. MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
  47. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
  48. return false;
  49. }
  50. auto input = cnode->inputs()[0];
  51. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  52. IsPrimitive(input, prim::kPrimTensorSummary) ||
  53. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  54. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  55. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  56. IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  57. return !is_virtual_node;
  58. }
  59. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  60. if (utils::isa<int>(sexp)) {
  61. return NewValueNode(utils::cast<int>(sexp));
  62. }
  63. if (utils::isa<float>(sexp)) {
  64. return NewValueNode(utils::cast<float>(sexp));
  65. }
  66. if (utils::isa<bool>(sexp)) {
  67. return NewValueNode(utils::cast<bool>(sexp));
  68. }
  69. if (utils::isa<ValuePtr>(sexp)) {
  70. return NewValueNode(utils::cast<ValuePtr>(sexp));
  71. }
  72. return nullptr;
  73. }
  74. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  75. if (utils::isa<FuncGraphPtr>(graph)) {
  76. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  77. }
  78. if (utils::isa<VarPtr>(graph)) {
  79. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  80. }
  81. return nullptr;
  82. }
  83. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  84. if (utils::isa<VarPtr>(graph)) {
  85. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  86. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  87. }
  88. if (utils::isa<FuncGraphPtr>(graph)) {
  89. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  90. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  91. }
  92. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  93. return nullptr;
  94. }
  95. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  96. bool multigraph) {
  97. if (primitive_vars == nullptr) {
  98. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  99. return nullptr;
  100. }
  101. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  102. std::vector<AnfNodePtr> input_nodes;
  103. const auto &tuple = utils::cast<VectorRef>(sexp);
  104. if (multigraph && utils::isa<VarPtr>(graph)) {
  105. for (auto &x : tuple) {
  106. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  107. input_nodes.push_back(node);
  108. }
  109. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  110. return std::make_shared<CNode>(input_nodes, var_ptr);
  111. }
  112. for (auto &x : tuple) {
  113. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  114. input_nodes.push_back(node);
  115. }
  116. return CreateCNodeWithGraph(input_nodes, graph);
  117. }
  118. } // namespace
  119. bool CheckInputs(const CNodePtr &cnode) {
  120. if (cnode == nullptr) {
  121. MS_LOG(ERROR) << "cnode is nullptr.";
  122. return false;
  123. }
  124. if (std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
  125. [](const AnfNodePtr &anf_node) { return anf_node == nullptr; })) {
  126. MS_LOG(ERROR) << "input is nullptr.";
  127. return false;
  128. }
  129. return true;
  130. }
  131. bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  132. if (node == nullptr) {
  133. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  134. return false;
  135. }
  136. if (!node->isa<CNode>()) {
  137. return false;
  138. }
  139. auto cnode = node->cast<CNodePtr>();
  140. if (cnode == nullptr) {
  141. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  142. return false;
  143. }
  144. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  145. }
  146. bool AnfEqualPrimitive(AnfNodePtr a_node, AnfNodePtr b_node) {
  147. auto a_value_node = a_node->cast<ValueNodePtr>();
  148. auto b_value_node = b_node->cast<ValueNodePtr>();
  149. if (a_value_node == nullptr || b_value_node == nullptr) {
  150. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  151. return false;
  152. }
  153. auto a_value = a_value_node->value();
  154. auto b_value = b_value_node->value();
  155. if (a_value == nullptr || b_value == nullptr) {
  156. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  157. return false;
  158. }
  159. auto a_prim = a_value->cast<PrimitivePtr>();
  160. auto b_prim = b_value->cast<PrimitivePtr>();
  161. if (a_prim == nullptr || b_prim == nullptr) {
  162. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  163. return false;
  164. }
  165. return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
  166. }
  167. bool AnfEqualValueNode(AnfNodePtr a_node, AnfNodePtr b_node) {
  168. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  169. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  170. if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
  171. MS_LOG(ERROR) << "cast value node ptr fail";
  172. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  173. return false;
  174. }
  175. auto a_value_ptr = a_value_node_ptr->value();
  176. auto b_value_ptr = b_value_node_ptr->value();
  177. if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
  178. MS_LOG(ERROR) << "value ptr is nullptr";
  179. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  180. return false;
  181. }
  182. if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
  183. auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
  184. auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
  185. return (*a_obj) == (*b_obj);
  186. } else {
  187. return (*a_value_ptr) == (*b_value_ptr);
  188. }
  189. }
// Generic equality used during pattern matching: compares two BaseRefs that
// may wrap AnfNodes, lite::PrimitiveC values, or plain values.
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
    auto a_node = utils::cast<AnfNodePtr>(a);
    auto b_node = utils::cast<AnfNodePtr>(b);
    if (a_node == nullptr || b_node == nullptr) {
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
      return false;
    }
    // Primitive value nodes compare by their PrimitiveC type id.
    if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
      return AnfEqualPrimitive(a_node, b_node);
    }
    // Other value nodes compare by their held Value.
    if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
      return AnfEqualValueNode(a_node, b_node);
    }
  }
  // NOTE(review): a.m_ptr / b.m_ptr are dereferenced without a null check on
  // this path — presumably BaseRef guarantees a non-null pointer here; confirm.
  if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
    auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
    auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
    return a_value_node_ptr->Type() == b_value_node_ptr->Type();
  }
  // Fall back to BaseRef's own operator==.
  return a == b;
}
  212. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
  213. // To matchCNode and Kernel's type
  214. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  215. return true;
  216. }
  217. return a.type() == b.type();
  218. }
  219. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  220. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  221. if (primitive_vars == nullptr) {
  222. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  223. return nullptr;
  224. }
  225. if (utils::isa<VectorRef>(sexp)) {
  226. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  227. }
  228. if (utils::isa<VarPtr>(sexp)) {
  229. auto var_ptr = utils::cast<VarPtr>(sexp);
  230. if (var_ptr == nullptr) {
  231. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  232. return nullptr;
  233. }
  234. if (var_ptr->primitive()) {
  235. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  236. return NewValueNode(var_ptr->primitive());
  237. }
  238. return CreateVarNodeWithSexp(sexp, graph);
  239. }
  240. if (utils::isa<AnfNodePtr>(sexp)) {
  241. return utils::cast<AnfNodePtr>(sexp);
  242. }
  243. auto value_node = CreateValueNodeWithSexp(sexp);
  244. if (value_node == nullptr) {
  245. MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
  246. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  247. return nullptr;
  248. }
  249. return value_node;
  250. }
  251. bool IsRealCNodeKernel(const AnfNodePtr &node) {
  252. if (node == nullptr) {
  253. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  254. return false;
  255. }
  256. // parameter and value node is not a real cnode kernel
  257. if (!node->isa<CNode>()) {
  258. return false;
  259. }
  260. // return considered as a real node
  261. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  262. return true;
  263. }
  264. return IsRealKernel(node);
  265. }
  266. bool IsGraphKernel(const AnfNodePtr &node) {
  267. if (node == nullptr) {
  268. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  269. return false;
  270. }
  271. // graph kernel should be a real cnode kernel.
  272. if (!IsRealCNodeKernel(node)) {
  273. return false;
  274. }
  275. auto cnode = node->cast<CNodePtr>();
  276. if (cnode == nullptr) {
  277. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  278. return false;
  279. }
  280. auto input = cnode->input(kAnfPrimitiveIndex);
  281. // graph kernel should has func_graph as first input.
  282. if (!IsValueNode<FuncGraph>(input)) {
  283. return false;
  284. }
  285. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  286. if (func_graph == nullptr) {
  287. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  288. return false;
  289. }
  290. return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  291. }
  292. int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  293. if (graph == nullptr) {
  294. MS_LOG(ERROR) << "The graph is null.";
  295. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  296. return lite::RET_NULL_PTR;
  297. }
  298. return lite::RET_OK;
  299. }
  300. int CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  301. if (node == nullptr) {
  302. MS_LOG(ERROR) << "The AnfNode is null.";
  303. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  304. return lite::RET_NULL_PTR;
  305. }
  306. return lite::RET_OK;
  307. }
  308. int CheckIfCNodeIsNull(const CNodePtr &node) {
  309. if (node == nullptr) {
  310. MS_LOG(ERROR) << "The CNode is null.";
  311. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  312. return lite::RET_NULL_PTR;
  313. }
  314. return lite::RET_OK;
  315. }
  316. int CheckIfVarIsNull(const VarPtr &var) {
  317. if (var == nullptr) {
  318. MS_LOG(ERROR) << "The Var is null.";
  319. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  320. return lite::RET_NULL_PTR;
  321. }
  322. return lite::RET_OK;
  323. }
// Returns RET_INVALID_OP_ATTR when the node exists but is not a Parameter.
// NOTE(review): a null node falls through to RET_OK — presumably intentional
// ("absent is fine, wrong-kind is not"), but confirm callers rely on this.
int CheckIfNodeIsParam(const AnfNodePtr &node) {
  if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
    MS_LOG(DEBUG) << "The Node is not param.";
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
    return lite::RET_INVALID_OP_ATTR;
  }
  return lite::RET_OK;
}
  332. int CheckInputSize(const CNodePtr &node, const int size) {
  333. if (static_cast<int>(node->inputs().size()) != size) {
  334. MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  335. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  336. return lite::RET_INVALID_OP_ATTR;
  337. }
  338. return lite::RET_OK;
  339. }
  340. int CheckLeastInputSize(const CNodePtr &node, const int size) {
  341. if (static_cast<int>(node->inputs().size()) < size) {
  342. MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
  343. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  344. return lite::RET_INVALID_OP_ATTR;
  345. }
  346. return lite::RET_OK;
  347. }
  348. ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
  349. const ParamValueLitePtr &weight_tensor) {
  350. auto bias_parameter = func_graph->add_parameter();
  351. MS_ASSERT(bias_parameter != nullptr);
  352. std::vector<int> shape = {kernel_num};
  353. std::vector<int64_t> shape_vector;
  354. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
  355. [](const int32_t &value) { return static_cast<int64_t>(value); });
  356. auto abstract_tensor =
  357. std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), shape_vector);
  358. bias_parameter->set_abstract(abstract_tensor);
  359. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  360. MS_ASSERT(param_value != nullptr);
  361. param_value->SetTensorData(bias_data, kernel_num * sizeof(float) / sizeof(uint8_t));
  362. param_value->set_format(weight_tensor->format());
  363. param_value->set_tensor_type(weight_tensor->tensor_type());
  364. param_value->set_tensor_shape(shape);
  365. bias_parameter->set_default_param(param_value);
  366. return bias_parameter;
  367. }
  368. schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  369. ValueNodePtr value_node;
  370. if (utils::isa<CNodePtr>(n)) {
  371. auto in = utils::cast<CNodePtr>(n);
  372. value_node = in->input(0)->cast<ValueNodePtr>();
  373. } else if (utils::isa<ValueNodePtr>(n)) {
  374. value_node = utils::cast<ValueNodePtr>(n);
  375. } else {
  376. MS_LOG(INFO) << "only value node or cnode has type";
  377. return schema::PrimitiveType_NONE;
  378. }
  379. if (value_node == nullptr) {
  380. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  381. return schema::PrimitiveType_NONE;
  382. }
  383. auto value = value_node->value();
  384. MS_ASSERT(value != nullptr);
  385. if (utils::isa<PrimitiveCPtr>(value)) {
  386. auto primitive = value->cast<PrimitiveCPtr>();
  387. MS_ASSERT(primitive != nullptr);
  388. return (schema::PrimitiveType)primitive->Type();
  389. } else if (utils::isa<Primitive>(value)) {
  390. auto primitive = value->cast<PrimitivePtr>();
  391. MS_ASSERT(primitive != nullptr);
  392. MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
  393. return schema::PrimitiveType_NONE;
  394. }
  395. return schema::PrimitiveType_NONE;
  396. }
  397. ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
  398. MS_ASSERT(node != nullptr);
  399. if (!utils::isa<ParameterPtr>(node)) {
  400. if (utils::isa<ValueNodePtr>(node)) {
  401. auto valueNode = node->cast<ValueNodePtr>();
  402. auto value = std::dynamic_pointer_cast<ParamValueLite>(valueNode->value());
  403. if (value != nullptr) {
  404. return value;
  405. }
  406. }
  407. MS_LOG(DEBUG) << "get lite param value node neither parameternode or valuenode";
  408. return nullptr;
  409. }
  410. auto param = node->cast<ParameterPtr>();
  411. MS_ASSERT(param != nullptr);
  412. auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
  413. return param_value;
  414. }
// Fetches the abstract (type/shape summary) of cnode's input at `index`.
// For a TupleGetItem input, the abstract is pulled out of the source node's
// AbstractTuple at the getitem index. Returns nullptr on any failure.
// Note: index 0 (the primitive) is deliberately excluded by the range check.
AbstractBasePtr GetCNodeInputAbstract(const CNodePtr &cnode, size_t index) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "CNodePtr is nullptr";
    return nullptr;
  }
  auto inputs = cnode->inputs();
  if (!(0 < index && index < inputs.size())) {
    return nullptr;
  }
  auto input = inputs[index];
  if (input == nullptr) {
    MS_LOG(ERROR) << "CNode input is nullptr";
    return nullptr;
  }
  AbstractBasePtr abstract = nullptr;
  if (utils::isa<ParameterPtr>(input)) {
    // Parameters carry their abstract directly.
    auto parameter = input->cast<ParameterPtr>();
    abstract = parameter->abstract();
  } else if (utils::isa<CNodePtr>(input)) {
    auto input_cnode = input->cast<CNodePtr>();
    if (GetCNodeType(input_cnode) == schema::PrimitiveType_TupleGetItem) {
      // input is TupleGetItem(src, idx): take element idx of src's tuple
      // abstract rather than TupleGetItem's own abstract.
      auto tuple_inputs = input_cnode->inputs();
      MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
      auto get_item_input_cnode = tuple_inputs.at(1);
      MS_ASSERT(get_item_input_cnode != nullptr);
      auto idx = GetTupleGetItemOutIndex(input_cnode);
      if (!utils::isa<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
        MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
        return nullptr;
      }
      auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
      auto abstract_list = abstract_tuple->elements();
      if (abstract_list.size() <= idx) {
        MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
        return nullptr;
      }
      abstract = abstract_list[idx];
    } else {
      abstract = input_cnode->abstract();
    }
  } else {
    MS_LOG(ERROR) << "unsupported input node type";
    return nullptr;
  }
  return abstract;
}
  461. bool IsParamNode(const BaseRef &n) {
  462. if (!utils::isa<ParameterPtr>(n)) {
  463. return false;
  464. }
  465. auto param = utils::cast<ParameterPtr>(n)->default_param();
  466. auto tensor = std::dynamic_pointer_cast<ParamValueLite>(param);
  467. if (tensor == nullptr) {
  468. return false;
  469. }
  470. return tensor->tensor_addr() != nullptr;
  471. }
  472. bool IsConvNode(const BaseRef &n) {
  473. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  474. auto type = opt::GetCNodeType(n);
  475. return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D ||
  476. type == schema::PrimitiveType_DeConv2D;
  477. }
  478. return false;
  479. }
  480. bool IsPoolingNode(const BaseRef &n) {
  481. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  482. auto type = opt::GetCNodeType(n);
  483. return type == schema::PrimitiveType_Pooling;
  484. }
  485. return false;
  486. }
  487. bool IsActivationNode(const BaseRef &n) {
  488. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  489. auto type = opt::GetCNodeType(n);
  490. return type == schema::PrimitiveType_Activation;
  491. }
  492. return false;
  493. }
  494. bool IsQuantNode(const BaseRef &n) {
  495. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  496. auto type = opt::GetCNodeType(n);
  497. return type == schema::PrimitiveType_QuantDTypeCast;
  498. }
  499. return false;
  500. }
  501. bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  502. if (node == nullptr) {
  503. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  504. return 0;
  505. }
  506. if (utils::isa<CNode>(node)) {
  507. auto cnode = node->cast<CNodePtr>();
  508. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  509. if (!utils::isa<Parameter>(cnode->input(i)) && !utils::isa<ValueNodePtr>(cnode->input(i))) {
  510. return false;
  511. }
  512. }
  513. return true;
  514. }
  515. return false;
  516. }
  517. size_t GetOutputTensorNum(const AnfNodePtr &node) {
  518. if (node == nullptr) {
  519. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  520. return 0;
  521. }
  522. auto type = node->Type();
  523. if (type == nullptr) {
  524. return 1;
  525. }
  526. if (type->isa<Tuple>()) {
  527. auto tuple_type = type->cast<TuplePtr>();
  528. if (tuple_type == nullptr) {
  529. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  530. return 0;
  531. }
  532. return tuple_type->size();
  533. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  534. return 1;
  535. } else if (type->isa<TypeNone>()) {
  536. return 0;
  537. } else {
  538. return 1;
  539. }
  540. }
  541. bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  542. if (node == nullptr || graph == nullptr) {
  543. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  544. return 0;
  545. }
  546. auto output_node_list = GetRealNodeUsedList(graph, node);
  547. if (output_node_list == nullptr) {
  548. MS_LOG(ERROR) << "output node list is nullptr";
  549. return false;
  550. }
  551. if (output_node_list->size() != 1) {
  552. MS_LOG(DEBUG) << "fusion node has multi output nodes";
  553. return true;
  554. }
  555. return false;
  556. }
  557. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  558. const AnfNodePtr &node) {
  559. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  560. if (graph == nullptr || node == nullptr) {
  561. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  562. return nullptr;
  563. }
  564. auto manager = graph->manager();
  565. if (manager == nullptr) {
  566. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  567. return nullptr;
  568. }
  569. auto iter = manager->node_users().find(node);
  570. if (iter == manager->node_users().end()) {
  571. MS_LOG(ERROR) << "node has no output in manager";
  572. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
  573. return nullptr;
  574. }
  575. auto output_info_list = iter->second;
  576. std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
  577. return output_node_list;
  578. }
// Reads the constant output index from a TupleGetItem node
// (TupleGetItem(src, idx) -> idx).
// NOTE(review): returning -1 from a size_t function yields SIZE_MAX; callers
// must treat a huge index as the error sentinel — confirm this is intended.
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_ASSERT(tuple_get_item != nullptr);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
    return -1;
  }
  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  MS_ASSERT(output_index_value_node != nullptr);
  auto value_node = output_index_value_node->cast<ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  // The index is stored as a constant; take the first casted int.
  return IntToSize(lite::CastToInt(value_node->value()).front());
}
// Collects the (user node, input slot) pairs that consume the given output
// index of node. Users that are TupleGetItem contribute their own getitem
// index; otherwise only output_index 0 is meaningful.
// Returns an empty list (never nullptr) when the node has no users.
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
                                                                                        const AnfNodePtr &node,
                                                                                        size_t output_index) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  auto iter = manager->node_users().find(node);
  if (iter == manager->node_users().end()) {
    MS_LOG(ERROR) << "node has no output in manager";
    return output_node_list;
  }
  auto output_info_list = iter->second;
  for (const auto &output_info : output_info_list) {
    size_t used_output_index;
    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
      // The user picks a specific output via its TupleGetItem index.
      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
      used_output_index = output_index;
    } else {
      // NOTE(review): both branches below return the (partial) list; the only
      // difference for output_index != 0 is the error log. Looks like an early
      // bail for single-output nodes — confirm the duplicated return is meant.
      if (output_index != 0) {
        MS_LOG(ERROR) << "node has no output in manager";
        return output_node_list;
      }
      return output_node_list;
    }
    if (used_output_index == output_index) {
      output_node_list->push_back(output_info);
    }
  }
  return output_node_list;
}
// Decodes the K/C/H/W filter dimensions out of a 4-D shape whose layout is
// implied by the *source* side of the requested transpose type. The map below
// buckets every transpose type by its source layout:
//   1: KCHW  2: CKHW  3: HWCK  4: HWKC  5: NHWC (N read as K)  6: CHWK  7: KHWC
// Returns RET_ERROR for an unknown type, RET_OK otherwise.
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
                    int32_t *filterH, int32_t *filterW) {
  MS_ASSERT(oriDims.size() == 4);
  std::unordered_map<kTransFilterType, int> maps = {
    {kKCHW2HWCK, 1}, {kKCHW2HWKC, 1}, {kKCHW2KHWC, 1}, {kKCHW2CKHW, 1}, {kCKHW2HWCK, 2},
    {kCKHW2HWKC, 2}, {kCKHW2KHWC, 2}, {kHWCK2KCHW, 3}, {kHWCK2CKHW, 3}, {kHWCK2KHWC, 3},
    {kHWKC2KCHW, 4}, {kHWKC2CKHW, 4}, {kHWKC2KHWC, 4}, {kNHWC2KCHW, 5}, {kNHWC2HWCK, 5},
    {kNHWC2CKHW, 5}, {kCHWK2HWCK, 6}, {kCHWK2KHWC, 6}, {kKHWC2HWCK, 7}, {kKHWC2CHWK, 7},
  };
  if (maps.find(type) == maps.end()) {
    MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
    return RET_ERROR;
  }
  switch (maps.find(type)->second) {
    case 1:  // source layout KCHW
      *filterK = oriDims.at(lite::KCHW_K);
      *filterC = oriDims.at(lite::KCHW_C);
      *filterH = oriDims.at(lite::KCHW_H);
      *filterW = oriDims.at(lite::KCHW_W);
      break;
    case 2:  // source layout CKHW
      *filterC = oriDims.at(lite::CKHW_C);
      *filterK = oriDims.at(lite::CKHW_K);
      *filterH = oriDims.at(lite::CKHW_H);
      *filterW = oriDims.at(lite::CKHW_W);
      break;
    case 3:  // source layout HWCK
      *filterH = oriDims.at(lite::HWCK_H);
      *filterW = oriDims.at(lite::HWCK_W);
      *filterC = oriDims.at(lite::HWCK_C);
      *filterK = oriDims.at(lite::HWCK_K);
      break;
    case 4:  // source layout HWKC
      *filterH = oriDims.at(lite::HWKC_H);
      *filterW = oriDims.at(lite::HWKC_W);
      *filterK = oriDims.at(lite::HWKC_K);
      *filterC = oriDims.at(lite::HWKC_C);
      break;
    case 5:  // source layout NHWC; the batch dim N is taken as K
      *filterK = oriDims.at(lite::NHWC_N);
      *filterH = oriDims.at(lite::NHWC_H);
      *filterW = oriDims.at(lite::NHWC_W);
      *filterC = oriDims.at(lite::NHWC_C);
      break;
    case 6:  // source layout CHWK
      *filterC = oriDims.at(lite::CHWK_C);
      *filterH = oriDims.at(lite::CHWK_H);
      *filterW = oriDims.at(lite::CHWK_W);
      *filterK = oriDims.at(lite::CHWK_K);
      break;
    case 7:  // source layout KHWC
      *filterK = oriDims.at(lite::KHWC_K);
      *filterH = oriDims.at(lite::KHWC_H);
      *filterW = oriDims.at(lite::KHWC_W);
      *filterC = oriDims.at(lite::KHWC_C);
      break;
    default:
      // Unreachable while the map only stores groups 1-7; kept defensively.
      MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
      return RET_ERROR;
  }
  return RET_OK;
}
  686. STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  687. int32_t filterH, int32_t filterW) {
  688. MS_ASSERT(tensor != nullptr);
  689. std::unordered_map<kTransFilterType, int> maps = {
  690. {kKCHW2HWCK, 1}, {kCKHW2HWCK, 1}, {kNHWC2HWCK, 1}, {kKHWC2HWCK, 1}, {kCHWK2HWCK, 1},
  691. {kKCHW2HWKC, 2}, {kCKHW2HWKC, 2}, {kHWCK2KCHW, 3}, {kHWKC2KCHW, 3}, {kNHWC2KCHW, 3},
  692. {kHWCK2CKHW, 4}, {kHWKC2CKHW, 4}, {kNHWC2CKHW, 4}, {kKCHW2CKHW, 4}, {kKHWC2CHWK, 5},
  693. {kKCHW2KHWC, 6}, {kCKHW2KHWC, 6}, {kCHWK2KHWC, 6}, {kHWCK2KHWC, 6}, {kHWKC2KHWC, 6},
  694. };
  695. if (maps.find(type) == maps.end()) {
  696. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  697. return RET_ERROR;
  698. }
  699. switch (maps.find(type)->second) {
  700. case 1:
  701. tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
  702. break;
  703. case 2:
  704. tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
  705. break;
  706. case 3:
  707. tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
  708. break;
  709. case 4:
  710. tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
  711. break;
  712. case 5:
  713. tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
  714. break;
  715. case 6:
  716. tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
  717. break;
  718. default:
  719. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  720. return RET_ERROR;
  721. }
  722. return RET_OK;
  723. }
  724. template <typename T>
  725. void TransFilterDataCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  726. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  727. for (int c = 0; c < filterC; ++c) {
  728. for (int h = 0; h < filterH; ++h) {
  729. for (int w = 0; w < filterW; ++w) {
  730. for (int k = 0; k < filterK; ++k) {
  731. p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  732. if (type == kCHWK2HWCK) {
  733. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  734. } else if (type == kCHWK2KHWC) {
  735. p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  736. }
  737. *p2Buff = *p1Buff;
  738. }
  739. }
  740. }
  741. }
  742. }
  743. template <typename T>
  744. void TransFilterDataKHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  745. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  746. for (int k = 0; k < filterK; ++k) {
  747. for (int h = 0; h < filterH; ++h) {
  748. for (int w = 0; w < filterW; ++w) {
  749. for (int c = 0; c < filterC; ++c) {
  750. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  751. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  752. *p2Buff = *p1Buff;
  753. }
  754. }
  755. }
  756. }
  757. }
  758. template <typename T>
  759. void TransFilterDataKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  760. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  761. for (int k = 0; k < filterK; ++k) {
  762. for (int c = 0; c < filterC; ++c) {
  763. for (int h = 0; h < filterH; ++h) {
  764. for (int w = 0; w < filterW; ++w) {
  765. p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  766. if (type == kKCHW2HWCK) {
  767. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  768. } else if (type == kKCHW2KHWC) {
  769. p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  770. } else if (type == kKCHW2CKHW) {
  771. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  772. } else {
  773. p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  774. }
  775. *p2Buff = *p1Buff;
  776. }
  777. }
  778. }
  779. }
  780. }
  781. template <typename T>
  782. void TransFilterDataCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  783. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  784. for (int c = 0; c < filterC; ++c) {
  785. for (int k = 0; k < filterK; ++k) {
  786. for (int h = 0; h < filterH; ++h) {
  787. for (int w = 0; w < filterW; ++w) {
  788. p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  789. if (type == kCKHW2HWCK) {
  790. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  791. } else if (type == kCKHW2KHWC) {
  792. p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  793. } else {
  794. p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  795. }
  796. *p2Buff = *p1Buff;
  797. }
  798. }
  799. }
  800. }
  801. }
  802. template <typename T>
  803. void TransFilterDataHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  804. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  805. for (int h = 0; h < filterH; ++h) {
  806. for (int w = 0; w < filterW; ++w) {
  807. for (int c = 0; c < filterC; ++c) {
  808. for (int k = 0; k < filterK; ++k) {
  809. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  810. if (type == kHWCK2KCHW) {
  811. p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  812. } else if (type == kHWCK2CKHW) {
  813. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  814. } else {
  815. p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  816. }
  817. *p2Buff = *p1Buff;
  818. }
  819. }
  820. }
  821. }
  822. }
  823. template <typename T>
  824. void TransFilterDataHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  825. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  826. for (int h = 0; h < filterH; ++h) {
  827. for (int w = 0; w < filterW; ++w) {
  828. for (int c = 0; c < filterC; ++c) {
  829. for (int k = 0; k < filterK; ++k) {
  830. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  831. if (type == kHWKC2KCHW) {
  832. p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  833. } else {
  834. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  835. }
  836. *p2Buff = *p1Buff;
  837. }
  838. }
  839. }
  840. }
  841. }
  842. template <typename T>
  843. void TransFilterDataNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  844. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  845. for (int k = 0; k < filterK; ++k) {
  846. for (int h = 0; h < filterH; ++h) {
  847. for (int w = 0; w < filterW; ++w) {
  848. for (int c = 0; c < filterC; ++c) {
  849. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  850. if (type == kNHWC2HWCK) {
  851. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  852. } else if (type == kNHWC2CKHW) {
  853. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  854. } else {
  855. p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  856. }
  857. *p2Buff = *p1Buff;
  858. }
  859. }
  860. }
  861. }
  862. }
  863. template <typename T>
  864. void TransFilterDataKHWC2CHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW,
  865. T *p1Buff, T *p2Buff, T *weightData, const std::unique_ptr<T[]> &buf) {
  866. for (int k = 0; k < filterK; ++k) {
  867. for (int h = 0; h < filterH; ++h) {
  868. for (int w = 0; w < filterW; ++w) {
  869. for (int c = 0; c < filterC; ++c) {
  870. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  871. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
  872. *p2Buff = *p1Buff;
  873. }
  874. }
  875. }
  876. }
  877. }
  878. template <typename T>
  879. static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  880. int32_t filterH, int32_t filterW) {
  881. MS_ASSERT(tensor != nullptr);
  882. int count = filterH * filterW * filterC * filterK;
  883. if (count <= 0) {
  884. MS_LOG(ERROR) << "Dim size invalid";
  885. return RET_ERROR;
  886. }
  887. std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  888. if (buf == nullptr) {
  889. MS_LOG(ERROR) << "new buf failed";
  890. return RET_ERROR;
  891. }
  892. void *originWeightData = tensor->tensor_addr();
  893. T *weightData = static_cast<T *>(originWeightData);
  894. if (weightData == nullptr) {
  895. MS_LOG(ERROR) << "weightData is nullptr";
  896. return RET_ERROR;
  897. }
  898. T *p1Buff = nullptr;
  899. T *p2Buff = nullptr;
  900. std::unordered_map<kTransFilterType, int> maps = {
  901. {kCHWK2HWCK, 1}, {kCHWK2KHWC, 1}, {kKHWC2HWCK, 2}, {kKCHW2HWCK, 3}, {kKCHW2CKHW, 3},
  902. {kKCHW2KHWC, 3}, {kKCHW2HWKC, 3}, {kCKHW2HWCK, 4}, {kCKHW2KHWC, 4}, {kCKHW2HWKC, 4},
  903. {kHWCK2KCHW, 5}, {kHWCK2CKHW, 5}, {kHWCK2KHWC, 5}, {kHWKC2KCHW, 6}, {kHWKC2KHWC, 6},
  904. {kHWKC2CKHW, 6}, {kNHWC2HWCK, 7}, {kNHWC2KCHW, 7}, {kNHWC2CKHW, 7}, {kKHWC2CHWK, 8},
  905. };
  906. if (maps.find(type) == maps.end()) {
  907. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  908. return RET_ERROR;
  909. }
  910. switch (maps.find(type)->second) {
  911. case 1: {
  912. TransFilterDataCHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  913. } break;
  914. case 2: {
  915. TransFilterDataKHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  916. } break;
  917. case 3: {
  918. TransFilterDataKCHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  919. } break;
  920. case 4: {
  921. TransFilterDataCKHW(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  922. } break;
  923. case 5: {
  924. TransFilterDataHWCK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  925. } break;
  926. case 6: {
  927. TransFilterDataHWKC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  928. } break;
  929. case 7: {
  930. TransFilterDataNHWC(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  931. } break;
  932. case 8: {
  933. TransFilterDataKHWC2CHWK(type, filterK, filterC, filterH, filterW, p1Buff, p2Buff, weightData, buf);
  934. } break;
  935. default: {
  936. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  937. return RET_ERROR;
  938. }
  939. }
  940. auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T));
  941. if (ret != EOK) {
  942. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  943. return RET_ERROR;
  944. }
  945. return RET_OK;
  946. }
  947. template <typename T>
  948. static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
  949. MS_ASSERT(tensor != nullptr);
  950. auto oriDims = tensor->tensor_shape();
  951. if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
  952. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  953. return lite::RET_ERROR;
  954. }
  955. int32_t filterH;
  956. int32_t filterW;
  957. int32_t filterC;
  958. int32_t filterK;
  959. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  960. if (status != lite::RET_OK) {
  961. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  962. return status;
  963. }
  964. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  965. if (status != lite::RET_OK) {
  966. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  967. return status;
  968. }
  969. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  970. if (status != lite::RET_OK) {
  971. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  972. return status;
  973. }
  974. return lite::RET_OK;
  975. }
  976. STATUS TransFilterFormatWithType(const ParamValueLitePtr &tensor, TypeId data_type,
  977. kTransFilterType trans_filter_type) {
  978. if (data_type == kNumberTypeFloat32) {
  979. return TransFilterFormat<float>(tensor, trans_filter_type);
  980. } else if (data_type == kNumberTypeUInt8) {
  981. return TransFilterFormat<uint8_t>(tensor, trans_filter_type);
  982. } else if (data_type == kNumberTypeInt8) {
  983. return TransFilterFormat<int8_t>(tensor, trans_filter_type);
  984. } else if (data_type == kNumberTypeFloat16) {
  985. return TransFilterFormat<float16>(tensor, trans_filter_type);
  986. } else {
  987. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  988. return RET_ERROR;
  989. }
  990. }
  991. STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
  992. if (tensor == nullptr) {
  993. return lite::RET_NULL_PTR;
  994. }
  995. auto ori_dims = tensor->tensor_shape();
  996. if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
  997. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
  998. return lite::RET_ERROR;
  999. }
  1000. auto src_format = tensor->format();
  1001. auto data_type = tensor->tensor_type();
  1002. lite::STATUS status;
  1003. std::unordered_map<schema::Format, kTransFilterType> khwc_trans_maps = {
  1004. {schema::Format::Format_KCHW, kKCHW2KHWC}, {schema::Format::Format_CKHW, kCKHW2KHWC},
  1005. {schema::Format::Format_CHWK, kCHWK2KHWC}, {schema::Format::Format_HWCK, kHWCK2KHWC},
  1006. {schema::Format::Format_HWKC, kHWKC2KHWC},
  1007. };
  1008. std::unordered_map<schema::Format, kTransFilterType> hwck_trans_maps = {
  1009. {schema::Format::Format_KCHW, kKCHW2HWCK},
  1010. {schema::Format::Format_KHWC, kKHWC2HWCK},
  1011. {schema::Format::Format_CKHW, kCKHW2HWCK},
  1012. {schema::Format::Format_CHWK, kCHWK2HWCK},
  1013. };
  1014. std::unordered_map<schema::Format, kTransFilterType> kchw_trans_maps = {
  1015. {schema::Format::Format_HWCK, kHWCK2KCHW}, {schema::Format::Format_HWKC, kHWKC2KCHW},
  1016. {schema::Format::Format_KHWC, kKHWC2KCHW}, {schema::Format::Format_CKHW, kCKHW2KCHW},
  1017. {schema::Format::Format_CHWK, kCHWK2KCHW},
  1018. };
  1019. std::unordered_map<schema::Format, kTransFilterType> ckhw_trans_maps = {{schema::Format::Format_HWCK, kHWCK2CKHW},
  1020. {schema::Format::Format_HWKC, kHWKC2CKHW},
  1021. {schema::Format::Format_KCHW, kKCHW2CKHW}};
  1022. std::unordered_map<schema::Format, kTransFilterType> chwk_trans_maps = {{schema::Format::Format_KHWC, kKHWC2CHWK}};
  1023. if (src_format == dst_format) {
  1024. return RET_OK;
  1025. }
  1026. switch (dst_format) {
  1027. case schema::Format::Format_KHWC: {
  1028. if (khwc_trans_maps.find(static_cast<const schema::Format>(src_format)) == khwc_trans_maps.end()) {
  1029. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
  1030. << " to " << EnumNameFormat(dst_format);
  1031. return RET_ERROR;
  1032. } else {
  1033. status = TransFilterFormatWithType(tensor, data_type,
  1034. khwc_trans_maps.find(static_cast<const schema::Format>(src_format))->second);
  1035. }
  1036. } break;
  1037. case schema::Format::Format_HWCK: {
  1038. if (hwck_trans_maps.find(static_cast<const schema::Format>(src_format)) == hwck_trans_maps.end()) {
  1039. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
  1040. << " to " << EnumNameFormat(dst_format);
  1041. return RET_ERROR;
  1042. } else {
  1043. status = TransFilterFormatWithType(tensor, data_type,
  1044. hwck_trans_maps.find(static_cast<const schema::Format>(src_format))->second);
  1045. }
  1046. } break;
  1047. case schema::Format::Format_KCHW: {
  1048. if (kchw_trans_maps.find(static_cast<const schema::Format>(src_format)) == kchw_trans_maps.end()) {
  1049. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
  1050. << " to " << EnumNameFormat(dst_format);
  1051. return RET_ERROR;
  1052. } else {
  1053. status = TransFilterFormatWithType(tensor, data_type,
  1054. kchw_trans_maps.find(static_cast<const schema::Format>(src_format))->second);
  1055. }
  1056. } break;
  1057. case schema::Format::Format_CKHW: {
  1058. if (ckhw_trans_maps.find(static_cast<const schema::Format>(src_format)) == ckhw_trans_maps.end()) {
  1059. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
  1060. << " to " << EnumNameFormat(dst_format);
  1061. return RET_ERROR;
  1062. } else {
  1063. status = TransFilterFormatWithType(tensor, data_type,
  1064. ckhw_trans_maps.find(static_cast<const schema::Format>(src_format))->second);
  1065. }
  1066. } break;
  1067. case schema::Format::Format_CHWK: {
  1068. if (chwk_trans_maps.find(static_cast<const schema::Format>(src_format)) == chwk_trans_maps.end()) {
  1069. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(static_cast<schema::Format>(src_format))
  1070. << " to " << EnumNameFormat(dst_format);
  1071. return RET_ERROR;
  1072. } else {
  1073. status = TransFilterFormatWithType(tensor, data_type,
  1074. chwk_trans_maps.find(static_cast<const schema::Format>(src_format))->second);
  1075. }
  1076. } break;
  1077. default:
  1078. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  1079. return RET_ERROR;
  1080. }
  1081. if (status != RET_OK) {
  1082. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  1083. return status;
  1084. }
  1085. return RET_OK;
  1086. }
  1087. ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
  1088. const std::string &node_name) {
  1089. MS_ASSERT(func_graph != nullptr);
  1090. MS_ASSERT(data.size() != 0);
  1091. auto param_node = func_graph->add_parameter();
  1092. auto type_ptr = TypeIdToType(kNumberTypeInt32);
  1093. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr);
  1094. param_node->set_abstract(abstract_tensor);
  1095. param_node->set_name(node_name);
  1096. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  1097. MS_ASSERT(param_value != nullptr);
  1098. param_value->set_tensor_shape({1});
  1099. param_value->set_tensor_type(kNumberTypeInt32);
  1100. char *default_data = new (std::nothrow) char[sizeof(int32_t)];
  1101. *(reinterpret_cast<int32_t *>(default_data)) = data;
  1102. param_value->SetTensorData(default_data, sizeof(int32_t));
  1103. param_node->set_default_param(param_value);
  1104. return param_node;
  1105. }
  1106. ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
  1107. const std::string &node_name) {
  1108. MS_ASSERT(func_graph != nullptr);
  1109. MS_ASSERT(data.size() != 0);
  1110. auto param_node = func_graph->add_parameter();
  1111. auto type_ptr = TypeIdToType(kNumberTypeInt32);
  1112. std::vector<int64_t> shape_vector{static_cast<int64_t>(data.size())};
  1113. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  1114. param_node->set_abstract(abstract_tensor);
  1115. param_node->set_name(node_name);
  1116. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  1117. MS_ASSERT(param_value != nullptr);
  1118. std::vector<int32_t> shape{static_cast<int32_t>(data.size())};
  1119. param_value->set_tensor_shape(shape);
  1120. param_value->set_tensor_type(kNumberTypeInt32);
  1121. char *default_data = new (std::nothrow) char[data.size() * sizeof(int32_t)];
  1122. if (memcpy_s(default_data, data.size() * sizeof(int32_t), data.data(), data.size() * sizeof(int32_t)) != EOK) {
  1123. MS_LOG(ERROR) << "memcpy data failed.";
  1124. delete[] default_data;
  1125. return nullptr;
  1126. }
  1127. param_value->SetTensorData(default_data, data.size() * sizeof(int32_t));
  1128. param_node->set_default_param(param_value);
  1129. return param_node;
  1130. }
  1131. ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
  1132. const std::string &node_name) {
  1133. MS_ASSERT(func_graph != nullptr);
  1134. MS_ASSERT(data.size() != 0);
  1135. auto param_node = func_graph->add_parameter();
  1136. auto type_ptr = TypeIdToType(kNumberTypeInt32);
  1137. std::vector<int64_t> shape_vector;
  1138. shape_vector.push_back(data.size());
  1139. shape_vector.push_back(2);
  1140. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  1141. param_node->set_abstract(abstract_tensor);
  1142. param_node->set_name(node_name);
  1143. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  1144. MS_ASSERT(param_value != nullptr);
  1145. std::vector<int32_t> shape;
  1146. shape.push_back(data.size());
  1147. shape.push_back(2);
  1148. param_value->set_tensor_shape(shape);
  1149. param_value->set_tensor_type(kNumberTypeInt32);
  1150. std::vector<int32_t> data_1d;
  1151. for (auto pair : data) {
  1152. data_1d.insert(data_1d.end(), pair.begin(), pair.end());
  1153. }
  1154. auto size = data_1d.size() * sizeof(int32_t);
  1155. char *default_data = new (std::nothrow) char[size];
  1156. if (memcpy_s(default_data, size, data_1d.data(), size) != EOK) {
  1157. MS_LOG(ERROR) << "memcpy data failed.";
  1158. delete[] default_data;
  1159. return nullptr;
  1160. }
  1161. param_value->SetTensorData(default_data, size);
  1162. param_node->set_default_param(param_value);
  1163. return param_node;
  1164. }
  1165. ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
  1166. const std::string &node_name) {
  1167. MS_ASSERT(func_graph != nullptr);
  1168. MS_ASSERT(data.size() != 0);
  1169. auto param_node = func_graph->add_parameter();
  1170. auto type_ptr = TypeIdToType(kNumberTypeFloat32);
  1171. std::vector<int64_t> shape_vector = {1};
  1172. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  1173. param_node->set_abstract(abstract_tensor);
  1174. param_node->set_name(node_name);
  1175. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  1176. MS_ASSERT(param_value != nullptr);
  1177. param_value->set_tensor_shape({1});
  1178. param_value->set_tensor_type(kNumberTypeFloat32);
  1179. char *default_data = new (std::nothrow) char[sizeof(float)];
  1180. if (memcpy_s(default_data, sizeof(float), &data, sizeof(float)) != EOK) {
  1181. MS_LOG(ERROR) << "memcpy data failed.";
  1182. delete[] default_data;
  1183. return nullptr;
  1184. }
  1185. param_value->SetTensorData(default_data, sizeof(float));
  1186. param_node->set_default_param(param_value);
  1187. return param_node;
  1188. }
  1189. } // namespace opt
  1190. } // namespace mindspore