
gllo_utils.cc 45 kB

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/common/gllo_utils.h"
  17. #include <vector>
  18. #include <algorithm>
  19. #include <utility>
  20. #include "src/ops/primitive_c.h"
  21. #include "src/common/common.h"
  22. #include "frontend/operator/ops.h"
  23. #include "backend/optimizer/common/helper.h"
  24. namespace mindspore {
  25. namespace opt {
  26. namespace {
  27. constexpr auto kAnfPrimitiveIndex = 0;
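// Returns true when `node` is a CNode whose primitive input matches `primitive_type`.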
  28. bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  29. if (node == nullptr) {
  30. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  31. return false;
  32. }
  33. if (!node->isa<CNode>()) {
  34. return false;
  35. }
  36. auto cnode = node->cast<CNodePtr>();
  37. if (cnode == nullptr) {
  38. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  39. return false;
  40. }
  41. return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  42. }
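// A node counts as a real kernel unless it is a CNode whose primitive is one of the
// virtual/bookkeeping ops listed below (Summary, MakeTuple, StateSetItem, Depend,
// TupleGetItem, ControlDepend, Return, Partial).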
  43. bool IsRealKernel(const AnfNodePtr &node) {
  44. if (node == nullptr) {
  45. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  46. return false;
  47. }
// Parameters and value nodes are not CNode kernels; treat them as real here.
  49. if (!node->isa<CNode>()) {
  50. return true;
  51. }
  52. auto cnode = node->cast<CNodePtr>();
  53. if (cnode == nullptr) {
  54. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  55. return false;
  56. }
  57. if (cnode->inputs().empty()) {
  58. MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
  59. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
  60. return false;
  61. }
  62. auto input = cnode->inputs()[0];
  63. bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
  64. IsPrimitive(input, prim::kPrimTensorSummary) ||
  65. IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
  66. IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
  67. IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
  68. IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
  69. return !is_virtual_node;
  70. }
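// Wraps a scalar (int/float/bool) or ValuePtr pattern element into a ValueNode;
// returns nullptr for unsupported types.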
  71. ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
  72. if (utils::isa<int>(sexp)) {
  73. return NewValueNode(utils::cast<int>(sexp));
  74. }
  75. if (utils::isa<float>(sexp)) {
  76. return NewValueNode(utils::cast<float>(sexp));
  77. }
  78. if (utils::isa<bool>(sexp)) {
  79. return NewValueNode(utils::cast<bool>(sexp));
  80. }
  81. if (utils::isa<ValuePtr>(sexp)) {
  82. return NewValueNode(utils::cast<ValuePtr>(sexp));
  83. }
  84. return nullptr;
  85. }
  86. CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
  87. if (utils::isa<FuncGraphPtr>(graph)) {
  88. return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
  89. }
  90. if (utils::isa<VarPtr>(graph)) {
  91. return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
  92. }
  93. return nullptr;
  94. }
  95. VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
  96. if (utils::isa<VarPtr>(graph)) {
  97. MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
  98. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
  99. }
  100. if (utils::isa<FuncGraphPtr>(graph)) {
  101. MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
  102. return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
  103. }
  104. MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
  105. return nullptr;
  106. }
  107. AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
  108. bool multigraph) {
  109. MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  110. std::vector<AnfNodePtr> input_nodes;
  111. const auto &tuple = utils::cast<VectorRef>(sexp);
  112. if (multigraph && utils::isa<VarPtr>(graph)) {
  113. for (auto &x : tuple) {
  114. AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
  115. input_nodes.push_back(node);
  116. }
  117. VarPtr var_ptr = utils::cast<VarPtr>(graph);
  118. return std::make_shared<CNode>(input_nodes, var_ptr);
  119. }
  120. for (auto &x : tuple) {
  121. AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
  122. input_nodes.push_back(node);
  123. }
  124. return CreateCNodeWithGraph(input_nodes, graph);
  125. }
  126. } // namespace
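// Equality used during pattern matching: primitive value nodes compare by their
// PrimitiveC type, other value nodes compare by value, and anything else falls back
// to BaseRef equality.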
  127. bool AnfEqual(const BaseRef &a, const BaseRef &b) {
  128. if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
  129. auto a_node = utils::cast<AnfNodePtr>(a);
  130. auto b_node = utils::cast<AnfNodePtr>(b);
  131. if (a_node == nullptr || b_node == nullptr) {
  132. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  133. return false;
  134. }
  135. if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
  136. auto a_value_node = a_node->cast<ValueNodePtr>();
  137. auto b_value_node = b_node->cast<ValueNodePtr>();
  138. if (a_value_node == nullptr || b_value_node == nullptr) {
  139. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  140. return false;
  141. }
  142. auto a_value = a_value_node->value();
  143. auto b_value = b_value_node->value();
  144. if (a_value == nullptr || b_value == nullptr) {
  145. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  146. return false;
  147. }
  148. auto a_prim = a_value->cast<PrimitivePtr>();
  149. auto b_prim = b_value->cast<PrimitivePtr>();
  150. if (a_prim == nullptr || b_prim == nullptr) {
  151. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  152. return false;
  153. }
  154. return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
  155. } else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
  156. auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
  157. auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
  158. if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
  159. MS_LOG(ERROR) << "cast value node ptr fail";
  160. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  161. return false;
  162. }
  163. auto a_value_ptr = a_value_node_ptr->value();
  164. auto b_value_ptr = b_value_node_ptr->value();
  165. if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
  166. MS_LOG(ERROR) << "value ptr is nullptr";
  167. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  168. return false;
  169. }
  170. if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
  171. auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
  172. auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
  173. return (*a_obj) == (*b_obj);
  174. } else {
  175. return (*a_value_ptr) == (*b_value_ptr);
  176. }
  177. }
  178. }
  179. if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
  180. auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
  181. auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
  182. return a_value_node_ptr->Type() == b_value_node_ptr->Type();
  183. }
  184. return a == b;
  185. }
  186. bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To match CNode and Kernel types.
  188. if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
  189. return true;
  190. }
  191. return a.type() == b.type();
  192. }
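// Converts a pattern s-expression (VectorRef, VarPtr, AnfNodePtr or scalar value) into
// an AnfNode, recording primitive variables in `primitive_vars` along the way.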
  193. AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
  194. MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
  195. if (primitive_vars == nullptr) {
  196. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  197. return nullptr;
  198. }
  199. if (utils::isa<VectorRef>(sexp)) {
  200. return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
  201. }
  202. if (utils::isa<VarPtr>(sexp)) {
  203. auto var_ptr = utils::cast<VarPtr>(sexp);
  204. if (var_ptr == nullptr) {
  205. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  206. return nullptr;
  207. }
  208. if (var_ptr->primitive()) {
  209. (*primitive_vars)[var_ptr->primitive()] = var_ptr;
  210. return NewValueNode(var_ptr->primitive());
  211. }
  212. return CreateVarNodeWithSexp(sexp, graph);
  213. }
  214. if (utils::isa<AnfNodePtr>(sexp)) {
  215. return utils::cast<AnfNodePtr>(sexp);
  216. }
  217. auto value_node = CreateValueNodeWithSexp(sexp);
  218. if (value_node == nullptr) {
  219. MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
  220. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  221. return nullptr;
  222. }
  223. return value_node;
  224. }
  225. bool IsRealCNodeKernel(const AnfNodePtr &node) {
  226. if (node == nullptr) {
  227. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  228. return false;
  229. }
// Parameters and value nodes are not real cnode kernels.
  231. if (!node->isa<CNode>()) {
  232. return false;
  233. }
// Return is considered a real node.
  235. if (CheckPrimitiveType(node, prim::kPrimReturn)) {
  236. return true;
  237. }
  238. return IsRealKernel(node);
  239. }
  240. bool IsGraphKernel(const AnfNodePtr &node) {
  241. if (node == nullptr) {
  242. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  243. return false;
  244. }
  245. // graph kernel should be a real cnode kernel.
  246. if (!IsRealCNodeKernel(node)) {
  247. return false;
  248. }
  249. auto cnode = node->cast<CNodePtr>();
  250. if (cnode == nullptr) {
  251. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  252. return false;
  253. }
  254. auto input = cnode->input(kAnfPrimitiveIndex);
// A graph kernel should have a func_graph as its first input.
  256. if (!IsValueNode<FuncGraph>(input)) {
  257. return false;
  258. }
  259. auto func_graph = GetValueNode<FuncGraphPtr>(input);
  260. if (func_graph == nullptr) {
  261. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  262. return false;
  263. }
  264. return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
  265. }
  266. int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
  267. if (graph == nullptr) {
  268. MS_LOG(ERROR) << "The graph is null.";
  269. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  270. return lite::RET_NULL_PTR;
  271. }
  272. return lite::RET_OK;
  273. }
  274. int CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
  275. if (node == nullptr) {
  276. MS_LOG(ERROR) << "The AnfNode is null.";
  277. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  278. return lite::RET_NULL_PTR;
  279. }
  280. return lite::RET_OK;
  281. }
  282. int CheckIfCNodeIsNull(const CNodePtr &node) {
  283. if (node == nullptr) {
  284. MS_LOG(ERROR) << "The CNode is null.";
  285. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  286. return lite::RET_NULL_PTR;
  287. }
  288. return lite::RET_OK;
  289. }
  290. int CheckIfVarIsNull(const VarPtr &var) {
  291. if (var == nullptr) {
  292. MS_LOG(ERROR) << "The Var is null.";
  293. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  294. return lite::RET_NULL_PTR;
  295. }
  296. return lite::RET_OK;
  297. }
  298. int CheckIfNodeIsParam(const AnfNodePtr &node) {
  299. if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
  300. MS_LOG(ERROR) << "The Node is not param.";
  301. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  302. return lite::RET_INVALID_OP_ATTR;
  303. }
  304. return lite::RET_OK;
  305. }
  306. int CheckInputSize(const CNodePtr &node, const int size) {
  307. if (static_cast<int>(node->inputs().size()) != size) {
MS_LOG(ERROR) << "The input size of the node must be " << size << ", but it is " << node->inputs().size();
  309. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  310. return lite::RET_INVALID_OP_ATTR;
  311. }
  312. return lite::RET_OK;
  313. }
  314. int CheckLeastInputSize(const CNodePtr &node, const int size) {
  315. if (static_cast<int>(node->inputs().size()) < size) {
MS_LOG(ERROR) << "The input size of the node must be at least " << size << ", but it is " << node->inputs().size();
  317. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  318. return lite::RET_INVALID_OP_ATTR;
  319. }
  320. return lite::RET_OK;
  321. }
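// Adds a new bias Parameter of shape {kernel_num} to `func_graph`, reusing the format
// and data type of `weight_tensor` and pointing its ParamValueLite at `bias_data`.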
  322. ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
  323. const ParamValueLitePtr &weight_tensor) {
  324. auto bias_parameter = func_graph->add_parameter();
  325. MS_ASSERT(bias_parameter != nullptr);
  326. std::vector<int> shape = {kernel_num};
  327. auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(weight_tensor->tensor_type()), shape);
  328. bias_parameter->set_abstract(abstract_tensor);
  329. ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  330. MS_ASSERT(param_value != nullptr);
  331. param_value->set_tensor_addr(bias_data);
  332. param_value->set_tensor_size(kernel_num * sizeof(float) / sizeof(uint8_t));
  333. param_value->set_format(weight_tensor->format());
  334. param_value->set_tensor_type(weight_tensor->tensor_type());
  335. param_value->set_tensor_shape(shape);
  336. bias_parameter->set_default_param(param_value);
  337. return bias_parameter;
  338. }
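// Returns the schema primitive type of a CNode (via its first input) or of a primitive
// value node; returns PrimitiveType_NONE when the type cannot be determined.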
  339. schema::PrimitiveType GetCNodeType(const BaseRef &n) {
  340. ValueNodePtr value_node;
  341. if (utils::isa<CNodePtr>(n)) {
  342. auto in = utils::cast<CNodePtr>(n);
  343. value_node = in->input(0)->cast<ValueNodePtr>();
  344. } else if (utils::isa<ValueNodePtr>(n)) {
  345. value_node = utils::cast<ValueNodePtr>(n);
  346. } else {
MS_LOG(ERROR) << "only a value node or cnode has a type";
  348. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
  349. return schema::PrimitiveType_NONE;
  350. }
  351. if (value_node == nullptr) {
  352. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  353. return schema::PrimitiveType_NONE;
  354. }
  355. auto value = value_node->value();
  356. MS_ASSERT(value != nullptr);
  357. if (utils::isa<PrimitiveCPtr>(value)) {
  358. auto primitive = value->cast<PrimitiveCPtr>();
  359. MS_ASSERT(primitive != nullptr);
  360. return (schema::PrimitiveType)primitive->Type();
  361. } else if (utils::isa<Primitive>(value)) {
  362. auto primitive = value->cast<PrimitivePtr>();
  363. MS_ASSERT(primitive != nullptr);
  364. MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
  365. return schema::PrimitiveType_NONE;
  366. }
  367. return schema::PrimitiveType_NONE;
  368. }
  369. ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
  370. MS_ASSERT(node != nullptr);
  371. if (!utils::isa<ParameterPtr>(node)) {
MS_LOG(ERROR) << "node must be a Parameter to get its lite param value";
  373. return nullptr;
  374. }
  375. auto param = node->cast<ParameterPtr>();
  376. MS_ASSERT(param != nullptr);
  377. auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
  378. return param_value;
  379. }
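// Returns true only if `n` is a Parameter whose default ParamValueLite holds concrete tensor data.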
  380. bool IsParamNode(const BaseRef &n) {
  381. if (!utils::isa<ParameterPtr>(n)) {
  382. return false;
  383. }
  384. auto param = utils::cast<ParameterPtr>(n)->default_param();
  385. auto tensor = std::dynamic_pointer_cast<ParamValueLite>(param);
  386. if (tensor == nullptr) {
  387. return false;
  388. }
  389. return tensor->tensor_addr() != nullptr;
  390. }
  391. bool IsConvNode(const BaseRef &n) {
  392. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  393. auto type = opt::GetCNodeType(n);
  394. return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D;
  395. }
  396. return false;
  397. }
  398. bool IsPoolingNode(const BaseRef &n) {
  399. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  400. auto type = opt::GetCNodeType(n);
  401. return type == schema::PrimitiveType_Pooling;
  402. }
  403. return false;
  404. }
  405. bool IsQuantNode(const BaseRef &n) {
  406. if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
  407. auto type = opt::GetCNodeType(n);
  408. return type == schema::PrimitiveType_QuantDTypeCast;
  409. }
  410. return false;
  411. }
  412. bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  413. if (utils::isa<CNode>(node)) {
  414. auto cnode = node->cast<CNodePtr>();
  415. for (size_t i = 1; i < cnode->inputs().size(); i++) {
  416. if (!utils::isa<Parameter>(cnode->input(i))) {
  417. return false;
  418. }
  419. }
  420. return true;
  421. }
  422. return false;
  423. }
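// Number of output tensors inferred from the node's type: tuples report their element
// count, TypeNone reports 0, and everything else (including an unset type) reports 1.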
  424. size_t GetOutputTensorNum(const AnfNodePtr &node) {
  425. if (node == nullptr) {
  426. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  427. return 0;
  428. }
  429. auto type = node->Type();
  430. if (type == nullptr) {
  431. return 1;
  432. }
  433. if (type->isa<Tuple>()) {
  434. auto tuple_type = type->cast<TuplePtr>();
  435. if (tuple_type == nullptr) {
  436. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  437. return 0;
  438. }
  439. return tuple_type->size();
  440. } else if (type->isa<TensorType>() || type->isa<Number>()) {
  441. return 1;
  442. } else if (type->isa<TypeNone>()) {
  443. return 0;
  444. } else {
  445. return 1;
  446. }
  447. }
  448. bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  449. auto output_node_list = GetRealNodeUsedList(graph, node);
  450. if (output_node_list->size() != 1) {
MS_LOG(DEBUG) << "fusion node has multiple output users";
  452. return true;
  453. }
  454. return false;
  455. }
  456. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
  457. const AnfNodePtr &node) {
  458. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  459. if (graph == nullptr) {
  460. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  461. return nullptr;
  462. }
  463. auto manager = graph->manager();
  464. if (manager == nullptr) {
  465. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
  466. return nullptr;
  467. }
  468. auto iter = manager->node_users().find(node);
  469. if (iter == manager->node_users().end()) {
  470. MS_LOG(ERROR) << "node has no output in manager";
  471. lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_FIND_OP);
  472. return nullptr;
  473. }
  474. auto output_info_list = iter->second;
  475. std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
  476. return output_node_list;
  477. }
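// Reads the constant output index carried by a TupleGetItem node (its index input).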
  478. size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  479. MS_ASSERT(tuple_get_item != nullptr);
  480. if (tuple_get_item->size() != kTupleGetItemInputSize) {
  481. MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
  482. return -1;
  483. }
  484. auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  485. MS_ASSERT(output_index_value_node != nullptr);
  486. auto value_node = output_index_value_node->cast<ValueNodePtr>();
  487. MS_ASSERT(value_node != nullptr);
  488. return IntToSize(GetValue<int>(value_node->value()));
  489. }
  490. std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
  491. const AnfNodePtr &node,
  492. size_t output_index) {
  493. MS_ASSERT(graph != nullptr);
  494. MS_ASSERT(node != nullptr);
  495. auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
  496. auto manager = graph->manager();
  497. MS_ASSERT(manager != nullptr);
  498. auto iter = manager->node_users().find(node);
  499. if (iter == manager->node_users().end()) {
  500. MS_LOG(ERROR) << "node has no output in manager";
  501. return output_node_list;
  502. }
  503. auto output_info_list = iter->second;
  504. for (const auto &output_info : output_info_list) {
  505. size_t used_output_index;
  506. if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
  507. used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
  508. } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
  509. used_output_index = output_index;
  510. } else {
  511. if (output_index != 0) {
MS_LOG(ERROR) << "invalid output index " << output_index << " for single-output node";
  513. return output_node_list;
  514. }
  515. return output_node_list;
  516. }
  517. if (used_output_index == output_index) {
  518. output_node_list->push_back(output_info);
  519. }
  520. }
  521. return output_node_list;
  522. }
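// Extracts filterK/filterC/filterH/filterW from `oriDims`, interpreting the dims in the
// source layout implied by the transform `type`.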
  523. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
  524. int32_t *filterH, int32_t *filterW) {
  525. MS_ASSERT(oriDims.size() == 4);
  526. if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
  527. *filterK = oriDims.at(lite::KCHW_K);
  528. *filterC = oriDims.at(lite::KCHW_C);
  529. *filterH = oriDims.at(lite::KCHW_H);
  530. *filterW = oriDims.at(lite::KCHW_W);
  531. } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
  532. *filterC = oriDims.at(lite::CKHW_C);
  533. *filterK = oriDims.at(lite::CKHW_K);
  534. *filterH = oriDims.at(lite::CKHW_H);
  535. *filterW = oriDims.at(lite::CKHW_W);
  536. } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
  537. *filterH = oriDims.at(lite::HWCK_H);
  538. *filterW = oriDims.at(lite::HWCK_W);
  539. *filterC = oriDims.at(lite::HWCK_C);
  540. *filterK = oriDims.at(lite::HWCK_K);
  541. } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
  542. *filterH = oriDims.at(lite::HWKC_H);
  543. *filterW = oriDims.at(lite::HWKC_W);
  544. *filterK = oriDims.at(lite::HWKC_K);
  545. *filterC = oriDims.at(lite::HWKC_C);
  546. } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
  547. *filterK = oriDims.at(lite::NHWC_N);
  548. *filterH = oriDims.at(lite::NHWC_H);
  549. *filterW = oriDims.at(lite::NHWC_W);
  550. *filterC = oriDims.at(lite::NHWC_C);
  551. } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
  552. *filterC = oriDims.at(lite::CHWK_C);
  553. *filterH = oriDims.at(lite::CHWK_H);
  554. *filterW = oriDims.at(lite::CHWK_W);
  555. *filterK = oriDims.at(lite::CHWK_K);
  556. } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
  557. *filterK = oriDims.at(lite::KHWC_K);
  558. *filterH = oriDims.at(lite::KHWC_H);
  559. *filterW = oriDims.at(lite::KHWC_W);
  560. *filterC = oriDims.at(lite::KHWC_C);
  561. } else {
  562. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  563. return RET_ERROR;
  564. }
  565. return RET_OK;
  566. }
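// Writes the destination-layout shape onto the tensor for the given transform type.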
  567. STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  568. int32_t filterH, int32_t filterW) {
  569. MS_ASSERT(tensor != nullptr);
  570. if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
  571. tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
  572. } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
  573. tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
  574. } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
  575. tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
  576. } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
  577. tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
  578. } else if (type == kKHWC2CHWK) {
  579. tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
  580. } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
  581. tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
  582. } else {
  583. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  584. return RET_ERROR;
  585. }
  586. return RET_OK;
  587. }
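// Permutes the raw weight data of `tensor` from the source layout to the destination
// layout implied by `type`, staging the result in a temporary buffer before copying it
// back in place.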
  588. template <typename T>
  589. static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
  590. int32_t filterH, int32_t filterW) {
  591. MS_ASSERT(tensor != nullptr);
  592. int count = filterH * filterW * filterC * filterK;
  593. if (count <= 0) {
  594. MS_LOG(ERROR) << "Dim size invalid";
  595. return RET_ERROR;
  596. }
  597. std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
  598. if (buf == nullptr) {
  599. MS_LOG(ERROR) << "new buf failed";
  600. return RET_ERROR;
  601. }
  602. void *originWeightData = tensor->tensor_addr();
  603. T *weightData = static_cast<T *>(originWeightData);
  604. if (weightData == nullptr) {
  605. MS_LOG(ERROR) << "weightData is nullptr";
  606. return RET_ERROR;
  607. }
  608. T *p1Buff = nullptr;
  609. T *p2Buff = nullptr;
  610. switch (type) {
  611. case kCHWK2HWCK:
  612. case kCHWK2KHWC: {
  613. for (int c = 0; c < filterC; ++c) {
  614. for (int h = 0; h < filterH; ++h) {
  615. for (int w = 0; w < filterW; ++w) {
  616. for (int k = 0; k < filterK; ++k) {
  617. p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
  618. if (type == kCHWK2HWCK) {
  619. p2Buff =
  620. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  621. } else if (type == kCHWK2KHWC) {
  622. p2Buff =
  623. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  624. }
  625. *p2Buff = *p1Buff;
  626. }
  627. }
  628. }
  629. }
  630. } break;
  631. case kKHWC2HWCK: {
  632. for (int k = 0; k < filterK; ++k) {
  633. for (int h = 0; h < filterH; ++h) {
  634. for (int w = 0; w < filterW; ++w) {
  635. for (int c = 0; c < filterC; ++c) {
  636. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  637. p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  638. *p2Buff = *p1Buff;
  639. }
  640. }
  641. }
  642. }
  643. } break;
  644. case kKCHW2HWCK:
  645. case kKCHW2CKHW:
  646. case kKCHW2KHWC:
  647. case kKCHW2HWKC: {
  648. for (int k = 0; k < filterK; ++k) {
  649. for (int c = 0; c < filterC; ++c) {
  650. for (int h = 0; h < filterH; ++h) {
  651. for (int w = 0; w < filterW; ++w) {
  652. p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  653. if (type == kKCHW2HWCK) {
  654. p2Buff =
  655. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  656. } else if (type == kKCHW2KHWC) {
  657. p2Buff =
  658. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  659. } else if (type == kKCHW2CKHW) {
  660. p2Buff =
  661. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  662. } else {
  663. p2Buff =
  664. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  665. }
  666. *p2Buff = *p1Buff;
  667. }
  668. }
  669. }
  670. }
  671. } break;
  672. case kCKHW2HWCK:
  673. case kCKHW2KHWC:
  674. case kCKHW2HWKC: {
  675. for (int c = 0; c < filterC; ++c) {
  676. for (int k = 0; k < filterK; ++k) {
  677. for (int h = 0; h < filterH; ++h) {
  678. for (int w = 0; w < filterW; ++w) {
  679. p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  680. if (type == kCKHW2HWCK) {
  681. p2Buff =
  682. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  683. } else if (type == kCKHW2KHWC) {
  684. p2Buff =
  685. buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  686. } else {
  687. p2Buff =
  688. buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
  689. }
  690. *p2Buff = *p1Buff;
  691. }
  692. }
  693. }
  694. }
  695. } break;
  696. case kHWCK2KCHW:
  697. case kHWCK2CKHW: {
  698. for (int h = 0; h < filterH; ++h) {
  699. for (int w = 0; w < filterW; ++w) {
  700. for (int c = 0; c < filterC; ++c) {
  701. for (int k = 0; k < filterK; ++k) {
  702. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  703. if (type == kHWCK2KCHW) {
  704. p2Buff =
  705. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  706. } else {
  707. p2Buff =
  708. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  709. }
  710. *p2Buff = *p1Buff;
  711. }
  712. }
  713. }
  714. }
  715. } break;
  716. case kHWKC2KCHW:
  717. case kHWKC2CKHW: {
  718. for (int h = 0; h < filterH; ++h) {
  719. for (int w = 0; w < filterW; ++w) {
  720. for (int c = 0; c < filterC; ++c) {
  721. for (int k = 0; k < filterK; ++k) {
  722. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  723. if (type == kHWKC2KCHW) {
  724. p2Buff =
  725. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  726. } else {
  727. p2Buff =
  728. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  729. }
  730. *p2Buff = *p1Buff;
  731. }
  732. }
  733. }
  734. }
  735. } break;
  736. case kNHWC2HWCK:
  737. case kNHWC2KCHW:
  738. case kNHWC2CKHW: {
  739. for (int k = 0; k < filterK; ++k) {
  740. for (int h = 0; h < filterH; ++h) {
  741. for (int w = 0; w < filterW; ++w) {
  742. for (int c = 0; c < filterC; ++c) {
  743. p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
  744. if (type == kNHWC2HWCK) {
  745. p2Buff =
  746. buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
  747. } else if (type == kNHWC2CKHW) {
  748. p2Buff =
  749. buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
  750. } else {
  751. p2Buff =
  752. buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
  753. }
  754. *p2Buff = *p1Buff;
  755. }
  756. }
  757. }
  758. }
  759. } break;
  760. case kKHWC2CHWK: {
  761. for (int k = 0; k < filterK; ++k) {
  762. for (int h = 0; h < filterH; ++h) {
  763. for (int w = 0; w < filterW; ++w) {
  764. for (int c = 0; c < filterC; ++c) {
  765. p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
  766. p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
  767. *p2Buff = *p1Buff;
  768. }
  769. }
  770. }
  771. }
  772. } break;
  773. default: {
  774. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  775. return RET_ERROR;
  776. }
  777. }
  778. auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T));
  779. if (ret != EOK) {
  780. MS_LOG(ERROR) << "memcpy_s failed: " << ret;
  781. return RET_ERROR;
  782. }
  783. return RET_OK;
  784. }
  785. template <typename T>
  786. static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
  787. MS_ASSERT(tensor != nullptr);
  788. auto oriDims = tensor->tensor_shape();
  789. if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
  790. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  791. return lite::RET_ERROR;
  792. }
  793. int32_t filterH;
  794. int32_t filterW;
  795. int32_t filterC;
  796. int32_t filterK;
  797. auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
  798. if (status != lite::RET_OK) {
  799. MS_LOG(ERROR) << "GetFilterDim failed: " << status;
  800. return status;
  801. }
  802. status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
  803. if (status != lite::RET_OK) {
  804. MS_LOG(ERROR) << "SetFilterDim failed: " << status;
  805. return status;
  806. }
  807. status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
  808. if (status != lite::RET_OK) {
  809. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  810. return status;
  811. }
  812. return lite::RET_OK;
  813. }
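// Public entry point: dispatches on (src_format, dst_format, data_type), converts the
// filter tensor in place, and returns RET_OK immediately when the formats already match.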
  814. STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
  815. if (tensor == nullptr) {
  816. return lite::RET_NULL_PTR;
  817. }
  818. auto ori_dims = tensor->tensor_shape();
  819. if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
  820. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
  821. return lite::RET_ERROR;
  822. }
  823. auto src_format = tensor->format();
  824. auto data_type = tensor->tensor_type();
  825. lite::STATUS status;
  826. switch (dst_format) {
  827. case schema::Format::Format_KHWC: {
  828. switch (src_format) {
  829. case schema::Format::Format_KCHW:
  830. if (data_type == kNumberTypeFloat32) {
  831. status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
  832. } else if (data_type == kNumberTypeUInt8) {
  833. status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
  834. } else if (data_type == kNumberTypeInt8) {
  835. status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
  836. } else if (data_type == kNumberTypeFloat16) {
  837. status = TransFilterFormat<float16>(tensor, kKCHW2KHWC);
  838. } else {
  839. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  840. return RET_ERROR;
  841. }
  842. break;
  843. case schema::Format::Format_CKHW:
  844. if (data_type == kNumberTypeFloat32) {
  845. status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
  846. } else if (data_type == kNumberTypeUInt8) {
  847. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
  848. } else if (data_type == kNumberTypeInt8) {
  849. status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
  850. } else if (data_type == kNumberTypeFloat16) {
  851. status = TransFilterFormat<float16>(tensor, kCKHW2KHWC);
  852. } else {
  853. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  854. return RET_ERROR;
  855. }
  856. break;
  857. case schema::Format::Format_CHWK:
  858. if (data_type == kNumberTypeFloat32) {
  859. status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
  860. } else if (data_type == kNumberTypeUInt8) {
  861. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
  862. } else if (data_type == kNumberTypeInt8) {
  863. status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
  864. } else if (data_type == kNumberTypeFloat16) {
  865. status = TransFilterFormat<float16>(tensor, kCHWK2KHWC);
  866. } else {
  867. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  868. return RET_ERROR;
  869. }
  870. break;
  871. case schema::Format::Format_KHWC:
  872. return RET_OK;
  873. default:
  874. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  875. return RET_ERROR;
  876. }
  877. } break;
  878. case schema::Format::Format_HWCK: {
  879. switch (src_format) {
  880. case schema::Format::Format_KCHW:
  881. if (data_type == kNumberTypeFloat32) {
  882. status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
  883. } else if (data_type == kNumberTypeUInt8) {
  884. status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
  885. } else if (data_type == kNumberTypeInt8) {
  886. status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
  887. } else if (data_type == kNumberTypeFloat16) {
  888. status = TransFilterFormat<float16>(tensor, kKCHW2HWCK);
  889. } else {
  890. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  891. return RET_ERROR;
  892. }
  893. break;
  894. case schema::Format::Format_KHWC:
  895. if (data_type == kNumberTypeFloat32) {
  896. status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
  897. } else if (data_type == kNumberTypeUInt8) {
  898. status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
  899. } else if (data_type == kNumberTypeInt8) {
  900. status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
  901. } else if (data_type == kNumberTypeFloat16) {
  902. status = TransFilterFormat<float16>(tensor, kKHWC2HWCK);
  903. } else {
  904. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  905. return RET_ERROR;
  906. }
  907. break;
  908. case schema::Format::Format_CKHW:
  909. if (data_type == kNumberTypeFloat32) {
  910. status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
  911. } else if (data_type == kNumberTypeUInt8) {
  912. status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
  913. } else if (data_type == kNumberTypeInt8) {
  914. status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
  915. } else if (data_type == kNumberTypeFloat16) {
  916. status = TransFilterFormat<float16>(tensor, kCKHW2HWCK);
  917. } else {
  918. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  919. return RET_ERROR;
  920. }
  921. break;
  922. case schema::Format::Format_CHWK:
  923. if (data_type == kNumberTypeFloat32) {
  924. status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
  925. } else if (data_type == kNumberTypeUInt8) {
  926. status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
  927. } else if (data_type == kNumberTypeInt8) {
  928. status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
  929. } else if (data_type == kNumberTypeFloat16) {
  930. status = TransFilterFormat<float16>(tensor, kCHWK2HWCK);
  931. } else {
  932. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  933. return lite::RET_ERROR;
  934. }
  935. break;
  936. case schema::Format::Format_HWCK:
  937. return RET_OK;
  938. default:
  939. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  940. return RET_ERROR;
  941. }
  942. } break;
  943. case schema::Format::Format_KCHW: {
  944. switch (src_format) {
  945. case schema::Format::Format_KCHW:
  946. return RET_OK;
  947. case schema::Format::Format_HWCK:
  948. if (data_type == kNumberTypeFloat32) {
  949. status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
  950. } else if (data_type == kNumberTypeUInt8) {
  951. status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
  952. } else if (data_type == kNumberTypeInt8) {
  953. status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
  954. } else if (data_type == kNumberTypeFloat16) {
  955. status = TransFilterFormat<float16>(tensor, kHWCK2KCHW);
  956. } else {
  957. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  958. return RET_ERROR;
  959. }
  960. break;
  961. case schema::Format::Format_HWKC:
  962. if (data_type == kNumberTypeFloat32) {
  963. status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
  964. } else if (data_type == kNumberTypeUInt8) {
  965. status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
  966. } else if (data_type == kNumberTypeInt8) {
  967. status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
  968. } else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kHWKC2KCHW);
  970. } else {
  971. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  972. return RET_ERROR;
  973. }
  974. break;
  975. case schema::Format::Format_KHWC:
  976. if (data_type == kNumberTypeFloat32) {
  977. status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
  978. } else if (data_type == kNumberTypeUInt8) {
  979. status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
  980. } else if (data_type == kNumberTypeInt8) {
  981. status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
  982. } else if (data_type == kNumberTypeFloat16) {
  983. status = TransFilterFormat<float16>(tensor, kKHWC2KCHW);
  984. } else {
  985. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  986. return RET_ERROR;
  987. }
  988. break;
  989. case schema::Format::Format_CKHW:
  990. if (data_type == kNumberTypeFloat32) {
  991. status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
  992. } else if (data_type == kNumberTypeUInt8) {
  993. status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
  994. } else if (data_type == kNumberTypeInt8) {
  995. status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
  996. } else if (data_type == kNumberTypeFloat16) {
  997. status = TransFilterFormat<float16>(tensor, kCKHW2KCHW);
  998. } else {
  999. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  1000. return RET_ERROR;
  1001. }
  1002. break;
  1003. case schema::Format::Format_CHWK:
  1004. if (data_type == kNumberTypeFloat32) {
  1005. status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
  1006. } else if (data_type == kNumberTypeUInt8) {
  1007. status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
  1008. } else if (data_type == kNumberTypeInt8) {
  1009. status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
  1010. } else if (data_type == kNumberTypeFloat16) {
status = TransFilterFormat<float16>(tensor, kCHWK2KCHW);
  1012. } else {
  1013. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  1014. return RET_ERROR;
  1015. }
  1016. break;
  1017. default:
  1018. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  1019. return RET_ERROR;
  1020. }
  1021. } break;
  1022. case schema::Format::Format_CKHW: {
  1023. switch (src_format) {
  1024. case schema::Format::Format_HWCK:
  1025. if (data_type == kNumberTypeFloat32) {
  1026. status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
  1027. } else if (data_type == kNumberTypeUInt8) {
  1028. status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
  1029. } else if (data_type == kNumberTypeInt8) {
  1030. status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
  1031. } else if (data_type == kNumberTypeFloat16) {
  1032. status = TransFilterFormat<float16>(tensor, kHWCK2CKHW);
  1033. } else {
  1034. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  1035. return RET_ERROR;
  1036. }
  1037. break;
  1038. case schema::Format::Format_HWKC:
  1039. if (data_type == kNumberTypeFloat32) {
  1040. status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
  1041. } else if (data_type == kNumberTypeUInt8) {
  1042. status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
  1043. } else if (data_type == kNumberTypeInt8) {
  1044. status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
  1045. } else if (data_type == kNumberTypeFloat16) {
  1046. status = TransFilterFormat<float16>(tensor, kHWKC2CKHW);
  1047. } else {
  1048. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  1049. return RET_ERROR;
  1050. }
  1051. break;
  1052. case schema::Format::Format_KCHW:
  1053. if (data_type == kNumberTypeFloat32) {
  1054. status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
  1055. } else if (data_type == kNumberTypeUInt8) {
  1056. status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
  1057. } else if (data_type == kNumberTypeInt8) {
  1058. status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
  1059. } else if (data_type == kNumberTypeFloat16) {
  1060. status = TransFilterFormat<float16>(tensor, kKCHW2CKHW);
  1061. } else {
  1062. MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
  1063. return RET_ERROR;
  1064. }
  1065. break;
  1066. case schema::Format::Format_CKHW:
  1067. return RET_OK;
  1068. default:
  1069. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  1070. return RET_ERROR;
  1071. }
  1072. } break;
  1073. default:
  1074. MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
  1075. return RET_ERROR;
  1076. }
  1077. if (status != RET_OK) {
  1078. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  1079. return status;
  1080. }
  1081. return RET_OK;
  1082. }
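// Usage sketch (illustrative only, not part of the original file): converting float32
// conv weights stored as KCHW into the KHWC layout, assuming `weight_tensor` is the
// ParamValueLite of the filter parameter:
//   auto status = opt::TransFilterFormat(weight_tensor, schema::Format::Format_KHWC);
//   if (status != lite::RET_OK) {
//     MS_LOG(ERROR) << "TransFilterFormat failed: " << status;
//   }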
  1083. } // namespace opt
  1084. } // namespace mindspore