You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

node_util.cc 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/common/node_util.h"
  17. #include <memory>
  18. #include <set>
  19. #include <vector>
  20. #include "src/ops/populate/populate_register.h"
  21. #include "src/common/common.h"
  22. #include "src/common/log_adapter.h"
  23. #include "tools/common/graph_util.h"
  24. #include "tools/common/tensor_util.h"
  25. #include "src/runtime/infer_manager.h"
  26. namespace mindspore {
  27. namespace lite {
  28. constexpr size_t kInitialSize = 1024;
  29. static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion,
  30. schema::PrimitiveType_Conv2DBackpropInputFusion,
  31. schema::PrimitiveType_AvgPoolGrad,
  32. schema::PrimitiveType_MaxPoolGrad,
  33. schema::PrimitiveType_BiasAddGrad,
  34. schema::PrimitiveType_BatchNormGrad,
  35. schema::PrimitiveType_ApplyMomentum,
  36. schema::PrimitiveType_SGD,
  37. schema::PrimitiveType_Adam,
  38. schema::PrimitiveType_ResizeGrad,
  39. schema::PrimitiveType_AvgPoolFusion,
  40. schema::PrimitiveType_MaxPoolFusion,
  41. schema::PrimitiveType_Conv2DFusion,
  42. schema::PrimitiveType_Conv2dTransposeFusion,
  43. schema::PrimitiveType_LRN,
  44. schema::PrimitiveType_Resize,
  45. schema::PrimitiveType_BatchNorm,
  46. schema::PrimitiveType_FusedBatchNorm,
  47. schema::PrimitiveType_PReLUFusion,
  48. schema::PrimitiveType_BiasAdd,
  49. schema::PrimitiveType_SpaceToDepth,
  50. schema::PrimitiveType_DepthToSpace,
  51. schema::PrimitiveType_TopKFusion,
  52. schema::PrimitiveType_BatchToSpace,
  53. schema::PrimitiveType_SpaceToBatch,
  54. schema::PrimitiveType_SpaceToBatchND};
  55. static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm};
  56. static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
  57. schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad,
  58. schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion,
  59. schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad};
  60. // index {} mean all inputs need insert
  61. static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
  62. {schema::PrimitiveType_BatchNormGrad, {0, 1}},
  63. {schema::PrimitiveType_Conv2DBackpropFilterFusion, {0, 1}},
  64. {schema::PrimitiveType_ApplyMomentum, {3}},
  65. {schema::PrimitiveType_SGD, {1}},
  66. {schema::PrimitiveType_Adam, {9}}};
  67. static const std::vector<schema::PrimitiveType> fp32FullOpList = {
  68. schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion,
  69. schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32
  70. static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
  71. static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2DFusion,
  72. schema::PrimitiveType_Conv2dTransposeFusion,
  73. schema::PrimitiveType_AddFusion,
  74. schema::PrimitiveType_Transpose,
  75. schema::PrimitiveType_AvgPoolFusion,
  76. schema::PrimitiveType_MaxPoolFusion,
  77. schema::PrimitiveType_Concat,
  78. schema::PrimitiveType_Softmax,
  79. schema::PrimitiveType_Reshape,
  80. schema::PrimitiveType_Activation,
  81. schema::PrimitiveType_Resize,
  82. schema::PrimitiveType_FullConnection,
  83. schema::PrimitiveType_ArgMaxFusion,
  84. schema::PrimitiveType_ArgMinFusion,
  85. schema::PrimitiveType_BatchNorm,
  86. schema::PrimitiveType_FusedBatchNorm,
  87. schema::PrimitiveType_BiasAdd,
  88. schema::PrimitiveType_DivFusion,
  89. schema::PrimitiveType_MulFusion,
  90. schema::PrimitiveType_SliceFusion,
  91. schema::PrimitiveType_Split,
  92. schema::PrimitiveType_Squeeze,
  93. schema::PrimitiveType_SubFusion,
  94. schema::PrimitiveType_StridedSlice,
  95. schema::PrimitiveType_TopKFusion,
  96. schema::PrimitiveType_Unsqueeze,
  97. schema::PrimitiveType_MatMul,
  98. schema::PrimitiveType_PadFusion,
  99. schema::PrimitiveType_ScaleFusion,
  100. schema::PrimitiveType_Cast,
  101. schema::PrimitiveType_Shape,
  102. schema::PrimitiveType_ExpandDims,
  103. schema::PrimitiveType_BatchToSpace,
  104. schema::PrimitiveType_BatchToSpaceND,
  105. schema::PrimitiveType_ReduceFusion,
  106. schema::PrimitiveType_Round,
  107. schema::PrimitiveType_Floor,
  108. schema::PrimitiveType_Ceil,
  109. schema::PrimitiveType_Abs,
  110. schema::PrimitiveType_Sin,
  111. schema::PrimitiveType_Cos,
  112. schema::PrimitiveType_Log,
  113. schema::PrimitiveType_Sqrt,
  114. schema::PrimitiveType_Rsqrt,
  115. schema::PrimitiveType_Square,
  116. schema::PrimitiveType_LogicalNot,
  117. schema::PrimitiveType_SpaceToBatch,
  118. schema::PrimitiveType_SpaceToBatchND,
  119. schema::PrimitiveType_DepthToSpace,
  120. schema::PrimitiveType_PowFusion,
  121. schema::PrimitiveType_GatherNd,
  122. schema::PrimitiveType_LeakyRelu,
  123. schema::PrimitiveType_Gather,
  124. schema::PrimitiveType_Equal,
  125. schema::PrimitiveType_NotEqual,
  126. schema::PrimitiveType_LessEqual,
  127. schema::PrimitiveType_Greater,
  128. schema::PrimitiveType_GreaterEqual,
  129. schema::PrimitiveType_Eltwise,
  130. schema::PrimitiveType_DetectionPostProcess,
  131. schema::PrimitiveType_Crop,
  132. schema::PrimitiveType_PriorBox,
  133. schema::PrimitiveType_QuantDTypeCast,
  134. schema::PrimitiveType_LayerNormFusion,
  135. schema::PrimitiveType_L2NormalizeFusion};
  136. static const std::vector<schema::PrimitiveType> needInsertOpList = {
  137. schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
  138. schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
  139. schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
  140. schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
  141. schema::PrimitiveType_ActivationGrad};
  142. static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
  143. std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
  144. std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }
  145. std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }
  146. std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
  147. std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; }
  148. std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
  149. std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
  150. std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
  151. std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
  152. const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
  153. auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
  154. fbb->Finish(prim_offset);
  155. auto prim_buf = fbb->GetBufferPointer();
  156. return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
  157. }
  158. STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
  159. mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
  160. MS_ASSERT(nullptr != dst_dims);
  161. if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
  162. MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
  163. << " <3 or src format is equal to dst format,not need convert";
  164. *dst_dims = src_dims;
  165. return RET_PARAM_INVALID;
  166. }
  167. std::vector<int32_t> nchw_dim;
  168. switch (src_format) {
  169. case schema::Format::Format_NCHW:
  170. nchw_dim = src_dims;
  171. break;
  172. case schema::Format::Format_NHWC:
  173. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  174. nchw_dim.push_back(src_dims[NHWC_N]);
  175. nchw_dim.push_back(src_dims[NHWC_C]);
  176. nchw_dim.push_back(src_dims[NHWC_H]);
  177. nchw_dim.push_back(src_dims[NHWC_W]);
  178. } else {
  179. nchw_dim.push_back(src_dims[HWC_C]);
  180. nchw_dim.push_back(src_dims[HWC_H]);
  181. nchw_dim.push_back(src_dims[HWC_W]);
  182. }
  183. break;
  184. default:
  185. MS_LOG(ERROR) << "Not support src format: " << EnumNameFormat(src_format);
  186. return RET_ERROR;
  187. }
  188. if (nchw_dim.empty()) {
  189. MS_LOG(ERROR) << "Param nchw_dim is empty!";
  190. return RET_ERROR;
  191. }
  192. switch (dst_format) {
  193. case schema::Format::Format_NCHW:
  194. *dst_dims = nchw_dim;
  195. break;
  196. case schema::Format::Format_NHWC:
  197. if (src_dims.size() == DIM_DEFAULT_SIZE) {
  198. dst_dims->push_back(nchw_dim[NCHW_N]);
  199. dst_dims->push_back(nchw_dim[NCHW_H]);
  200. dst_dims->push_back(nchw_dim[NCHW_W]);
  201. dst_dims->push_back(nchw_dim[NCHW_C]);
  202. }
  203. break;
  204. default:
  205. MS_LOG(ERROR) << "Not support dst format: " << dst_format;
  206. return RET_ERROR;
  207. }
  208. return RET_OK;
  209. }
  210. static bool IsKCHWSource(kTransFilterType type) {
  211. return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW);
  212. }
  213. static bool IsCKHWSource(kTransFilterType type) {
  214. return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC);
  215. }
  216. static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); }
  217. static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); }
  218. static bool IsNHWCSource(kTransFilterType type) {
  219. return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW);
  220. }
  221. static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); }
  222. static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); }
  223. STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
  224. int32_t *filterH, int32_t *filterW) {
  225. if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) {
  226. MS_LOG(ERROR) << "null input";
  227. return RET_NULL_PTR;
  228. }
  229. MS_ASSERT(oriDims.size() == 4);
  230. if (IsKCHWSource(type)) {
  231. *filterK = oriDims.at(KCHW_K);
  232. *filterC = oriDims.at(KCHW_C);
  233. *filterH = oriDims.at(KCHW_H);
  234. *filterW = oriDims.at(KCHW_W);
  235. } else if (IsCKHWSource(type)) {
  236. *filterC = oriDims.at(CKHW_C);
  237. *filterK = oriDims.at(CKHW_K);
  238. *filterH = oriDims.at(CKHW_H);
  239. *filterW = oriDims.at(CKHW_W);
  240. } else if (IsHWCKSource(type)) {
  241. *filterH = oriDims.at(HWCK_H);
  242. *filterW = oriDims.at(HWCK_W);
  243. *filterC = oriDims.at(HWCK_C);
  244. *filterK = oriDims.at(HWCK_K);
  245. } else if (IsHWKCSource(type)) {
  246. *filterH = oriDims.at(HWKC_H);
  247. *filterW = oriDims.at(HWKC_W);
  248. *filterK = oriDims.at(HWKC_K);
  249. *filterC = oriDims.at(HWKC_C);
  250. } else if (IsNHWCSource(type)) {
  251. *filterK = oriDims.at(NHWC_N);
  252. *filterH = oriDims.at(NHWC_H);
  253. *filterW = oriDims.at(NHWC_W);
  254. *filterC = oriDims.at(NHWC_C);
  255. } else if (IsCHWKSource(type)) {
  256. *filterC = oriDims.at(CHWK_C);
  257. *filterH = oriDims.at(CHWK_H);
  258. *filterW = oriDims.at(CHWK_W);
  259. *filterK = oriDims.at(CHWK_K);
  260. } else if (IsKHWCSource(type)) {
  261. *filterK = oriDims.at(KHWC_K);
  262. *filterH = oriDims.at(KHWC_H);
  263. *filterW = oriDims.at(KHWC_W);
  264. *filterC = oriDims.at(KHWC_C);
  265. } else {
  266. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  267. return RET_ERROR;
  268. }
  269. return RET_OK;
  270. }
  271. STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH,
  272. int32_t filterW) {
  273. MS_ASSERT(tensor != nullptr);
  274. if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
  275. tensor->dims = {filterH, filterW, filterC, filterK};
  276. } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
  277. tensor->dims = {filterH, filterW, filterK, filterC};
  278. } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
  279. tensor->dims = {filterK, filterC, filterH, filterW};
  280. } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
  281. tensor->dims = {filterC, filterK, filterH, filterW};
  282. } else if (type == kKHWC2CHWK) {
  283. tensor->dims = {filterC, filterH, filterW, filterK};
  284. } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
  285. tensor->dims = {filterK, filterH, filterW, filterC};
  286. } else {
  287. MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
  288. return RET_ERROR;
  289. }
  290. return RET_OK;
  291. }
  292. static int Convert2KHWC(int srcFormat) {
  293. if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC;
  294. if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC;
  295. if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC;
  296. return -1;
  297. }
  298. static int Convert2HWCK(int srcFormat) {
  299. if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK;
  300. if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK;
  301. if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK;
  302. if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK;
  303. return -1;
  304. }
  305. static int Convert2KCHW(int srcFormat) {
  306. if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW;
  307. if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW;
  308. if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW;
  309. if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW;
  310. if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW;
  311. return -1;
  312. }
  313. static int Convert2CKHW(int srcFormat) {
  314. if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW;
  315. if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW;
  316. if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW;
  317. return -1;
  318. }
  319. STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &inputs, std::vector<Tensor *> *outputs) {
  320. flatbuffers::FlatBufferBuilder fbb(kInitialSize);
  321. auto prim = ConvertToPrimitive(node.primitive.get(), &fbb);
  322. if (prim == nullptr) {
  323. MS_LOG(ERROR) << "get primitive failed.";
  324. fbb.Clear();
  325. return RET_ERROR;
  326. }
  327. auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR);
  328. if (parameter_gen == nullptr) {
  329. fbb.Clear();
  330. MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
  331. return RET_ERROR;
  332. }
  333. auto parameter = parameter_gen(prim);
  334. if (parameter == nullptr) {
  335. fbb.Clear();
  336. MS_LOG(ERROR) << "parameter is nullptr.";
  337. return RET_ERROR;
  338. }
  339. auto ret = KernelInferShape(inputs, *outputs, parameter);
  340. fbb.Clear();
  341. free(parameter);
  342. return ret;
  343. }
  344. size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
  345. size_t ret = -1;
  346. for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
  347. if (cnode.inputIndex.at(i) == tensor_index) {
  348. ret = i;
  349. }
  350. }
  351. return ret;
  352. }
  353. STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  354. if (tensor == nullptr) {
  355. MS_LOG(ERROR) << "tensor is null";
  356. return RET_NULL_PTR;
  357. }
  358. std::vector<int32_t> oriDims = tensor->dims;
  359. if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) {
  360. MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
  361. return RET_ERROR;
  362. }
  363. auto srcFormat = tensor->format;
  364. auto dataType = tensor->dataType;
  365. STATUS status;
  366. int convert = -1;
  367. if (dstFormat == srcFormat) return RET_OK;
  368. switch (dstFormat) {
  369. case schema::Format::Format_KHWC:
  370. convert = Convert2KHWC(srcFormat);
  371. break;
  372. case schema::Format::Format_HWCK:
  373. convert = Convert2HWCK(srcFormat);
  374. break;
  375. case schema::Format::Format_KCHW:
  376. convert = Convert2KCHW(srcFormat);
  377. break;
  378. case schema::Format::Format_CKHW:
  379. convert = Convert2CKHW(srcFormat);
  380. break;
  381. default:
  382. convert = -1;
  383. }
  384. if (convert == -1) {
  385. MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat);
  386. return RET_ERROR;
  387. }
  388. if (dataType == kNumberTypeFloat32) {
  389. status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert));
  390. } else if (dataType == kNumberTypeUInt8) {
  391. status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert));
  392. } else if (dataType == kNumberTypeInt8) {
  393. status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert));
  394. } else {
  395. MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
  396. return RET_ERROR;
  397. }
  398. if (status != RET_OK) {
  399. MS_LOG(ERROR) << "TransFilterData failed: " << status;
  400. return status;
  401. }
  402. return RET_OK;
  403. }
  404. size_t GetCNodeOutputsSize(const std::shared_ptr<AnfNode> &anf_node, bool train_flag) {
  405. auto cnode = anf_node->cast<CNodePtr>();
  406. if (train_flag &&
  407. (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) {
  408. return 1;
  409. }
  410. if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
  411. auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
  412. return tuple->elements().size();
  413. } else {
  414. return 1;
  415. }
  416. }
  417. } // namespace lite
  418. } // namespace mindspore