You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

trans.cc 78 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
optimize the comment and log description 修改: ops/operations/_inner_ops.py 修改: ops/operations/_quant_ops.py 修改: ops/operations/array_ops.py 修改: ops/operations/comm_ops.py 修改: ops/operations/math_ops.py 修改: ops/operations/quantum_ops.py 修改: ops/operations/rl_ops.py 修改: ops/operations/sponge_ops.py 修改: ops/operations/sponge_update_ops.py 修改: train/__init__.py 修改: common/tensor.py 修改: train/serialization.py 修改: ccsrc/pipeline/jit/parse/parse.h 修改: explainer/benchmark/_attribution/metric.py 修改: ops/composite/multitype_ops/_constexpr_utils.py 修改: ops/operations/comm_ops.py 修改: RELEASE.md 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/concat_offset_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_shape_cpu_kernel.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc 修改: mindspore/ccsrc/frontend/parallel/strategy.h 修改: mindspore/common/tensor.py 修改: mindspore/core/abstract/prim_arrays.cc 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/core/ops/conv2d.cc 修改: mindspore/core/ops/logical_and.h 修改: mindspore/core/ops/logical_not.h 修改: mindspore/core/ops/logical_or.h 修改: mindspore/core/ops/reduce_all.h 修改: mindspore/core/ops/reduce_any.h 修改: mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc 修改: mindspore/nn/layer/quant.py 修改: mindspore/nn/optim/sgd.py 修改: mindspore/nn/sparse/sparse.py 修改: mindspore/numpy/array_creations.py 修改: mindspore/numpy/array_ops.py 修改: mindspore/numpy/logic_ops.py 修改: mindspore/numpy/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/train/_utils.py 修改: tests/ut/python/model/test_lenet_core_after_exception.py 修改: mindspore/_extends/parse/standard_method.py 修改: mindspore/ops/operations/rl_ops.py 修改: mindspore/core/abstract/prim_nn.cc 修改: 
mindspore/core/ops/conv2d.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ctcloss_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/rolling_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_arithmetic_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h 修改: mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h 修改: mindspore/ccsrc/fl/server/server.cc 修改: mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc 修改: mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h 修改: mindspore/ccsrc/frontend/optimizer/irpass/inline.h 修改: mindspore/ccsrc/minddata/dataset/core/device_tensor.cc 修改: mindspore/ccsrc/minddata/dataset/core/tensor.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/emnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/datasetops/source/qmnist_op.cc 修改: mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.cc 修改: mindspore/ccsrc/minddata/dataset/engine/opt/pre/epoch_ctrl_pass.cc 修改: mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc 修改: mindspore/ccsrc/pipeline/jit/action.cc 修改: mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc 
修改: mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_adapter.cc 修改: mindspore/compression/quant/quant_utils.py 修改: mindspore/core/abstract/prim_nn.cc 修改: mindspore/dataset/engine/validators.py 修改: mindspore/lite/micro/coder/opcoders/nnacl/fp32/affine_fp32_coder.cc 修改: mindspore/lite/micro/coder/opcoders/nnacl/int8/affine_int8_coder.cc 修改: mindspore/lite/src/runtime/kernel/ascend310/src/custom_kernel.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc 修改: mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/lite/tools/optimizer/fisson/fisson_util.cc 修改: mindspore/ops/composite/math_ops.py 修改: mindspore/ops/operations/_inner_ops.py 修改: mindspore/ops/operations/array_ops.py 修改: mindspore/ops/operations/math_ops.py 修改: mindspore/ops/operations/other_ops.py 修改: mindspore/boost/boost_cell_wrapper.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc 修改: mindspore/ccsrc/common/trans.cc 修改: mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/lite/src/common/log_util.h 修改: mindspore/nn/wrap/loss_scale.py 修改: mindspore/parallel/nn/moe.py 修改: tests/mindspore_test_framework/mindspore_test.py 修改: mindspore/ccsrc/backend/kernel_compiler/cpu/split_cpu_kernel.cc 修改: mindspore/lite/tools/common/graph_util.h 修改: mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc 修改: mindspore/core/ops/conv2d.cc 修改: tests/ut/python/model/test_lenet_core_after_exception.py
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/trans.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include <algorithm>
  21. #include "utils/ms_utils.h"
  22. #include "abstract/utils.h"
  23. #include "backend/kernel_compiler/kernel.h"
  24. #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
  25. #include "runtime/device/convert_tensor_utils.h"
  26. #include "utils/convert_utils.h"
  27. #include "utils/log_adapter.h"
  28. #include "utils/utils.h"
  29. using mindspore::abstract::Shape;
  30. namespace mindspore {
  31. namespace trans {
// Element byte-widths supported by SetData (1-, 2-, 4- and 8-byte copies).
const int b1 = 1;
const int b2 = 2;
const int b4 = 4;
const int b8 = 8;
  36. inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
  37. switch (size) {
  38. case b1:
  39. static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
  40. break;
  41. case b2:
  42. static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
  43. break;
  44. case b4:
  45. static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
  46. break;
  47. case b8:
  48. static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
  49. break;
  50. default:
  51. MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  52. }
  53. }
  54. // greatest common divsor
  55. size_t Gcd(size_t a, size_t b) {
  56. if (b == 0) {
  57. return 0;
  58. }
  59. size_t c = b;
  60. while (a % b != 0) {
  61. c = a % b;
  62. a = b;
  63. b = c;
  64. }
  65. return c;
  66. }
  67. // least common multiple
  68. size_t Lcm(size_t a, size_t b) {
  69. if (b == 0) {
  70. return 0;
  71. }
  72. size_t ret = (a * b) / (Gcd(a, b));
  73. return ret;
  74. }
  75. template <typename T>
  76. T DivCeil(T n1, T n2) {
  77. if (n2 != 0) {
  78. return (n1 + n2 - 1) / n2;
  79. }
  80. return 0;
  81. }
  82. size_t GetShapeSize(const std::vector<size_t> &shape) {
  83. return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
  84. }
// Every host<->device element-type conversion supported by CastKernel.
// Naming: FROM_<src>_TO_<dst>; see mode_map for the (TypeId, TypeId) lookup.
enum class DataTypeTransMode {
  FROM_FLOAT_TO_FLOAT16,
  FROM_FLOAT_TO_INT32,
  FROM_FLOAT16_TO_FLOAT,
  FROM_FLOAT16_TO_INT32,
  FROM_FLOAT16_TO_UINT8,
  FROM_INT32_TO_FLOAT,
  FROM_INT32_TO_FLOAT16,
  FROM_INT32_TO_UINT8,
  FROM_INT32_TO_INT8,
  FROM_INT32_TO_INT64,
  FROM_INT32_TO_BOOL,
  FROM_UINT8_TO_FLOAT,
  FROM_UINT8_TO_INT32,
  FROM_UINT8_TO_FLOAT16,
  FROM_INT8_TO_FLOAT,
  FROM_INT8_TO_FLOAT16,
  FROM_INT8_TO_INT32,
  FROM_INT64_TO_INT32,
  FROM_UINT16_TO_INT32,
  FROM_BOOL_TO_FLOAT,
  FROM_BOOL_TO_INT32,
  FROM_BOOL_TO_UINT8,
  FROM_BOOL_TO_FLOAT16,
  FROM_FLOAT64_TO_FLOAT32,
  FROM_FLOAT32_TO_FLOAT64
};
// Maps a (source TypeId, destination TypeId) pair to the conversion mode
// dispatched by CastKernel. Pairs absent from this map are unsupported casts.
// NOTE(review): the FROM_BOOL_TO_FLOAT entry keys on kNumberTypeFloat while
// every other float32 entry keys on kNumberTypeFloat32 — confirm intentional.
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), DataTypeTransMode::FROM_FLOAT_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), DataTypeTransMode::FROM_FLOAT16_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), DataTypeTransMode::FROM_FLOAT16_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), DataTypeTransMode::FROM_FLOAT16_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), DataTypeTransMode::FROM_INT32_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), DataTypeTransMode::FROM_INT32_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), DataTypeTransMode::FROM_INT32_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), DataTypeTransMode::FROM_INT32_TO_INT8},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), DataTypeTransMode::FROM_INT32_TO_INT64},
  {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), DataTypeTransMode::FROM_INT32_TO_BOOL},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_UINT8_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), DataTypeTransMode::FROM_UINT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_UINT8_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), DataTypeTransMode::FROM_INT8_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), DataTypeTransMode::FROM_INT8_TO_FLOAT16},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), DataTypeTransMode::FROM_INT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), DataTypeTransMode::FROM_INT64_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), DataTypeTransMode::FROM_UINT16_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), DataTypeTransMode::FROM_BOOL_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), DataTypeTransMode::FROM_BOOL_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), DataTypeTransMode::FROM_BOOL_TO_UINT8},
  {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), DataTypeTransMode::FROM_BOOL_TO_FLOAT16}};
  138. void CheckMemSize(const TypeIdArgs &args) {
  139. auto src_type_size = abstract::TypeIdSize(args.host_data_type);
  140. auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
  141. if (src_type_size < 1 || dst_type_size < 1) {
  142. MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  143. }
  144. if (args.data_size / src_type_size != args.host_shape_size) {
  145. MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  146. }
  147. }
  148. template <typename SrcT, typename DstT>
  149. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
  150. CheckMemSize(args);
  151. for (size_t idx = 0; idx != data_size; idx++) {
  152. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  153. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  154. }
  155. }
  156. template <typename SrcT>
  157. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
  158. CheckMemSize(args);
  159. auto src_data = static_cast<const SrcT *>(args.data);
  160. auto half_data = static_cast<float16 *>(dst);
  161. for (size_t i = 0; i < data_size; i++) {
  162. half_data[i] = float16(src_data[i]);
  163. }
  164. }
// Dispatches one element-type conversion over data_size elements.
// float<->float16 use the dedicated device helpers; all other modes go
// through the kernel table. Returns false (after logging an error) for an
// unsupported mode. Bool endpoints are handled via int8_t kernels, i.e.
// bool data is reinterpreted as 1-byte integers.
bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
  using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
  const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
    {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
    {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
    {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
    {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
    {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
    {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
    {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
    {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
    {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
    {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
    {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
    {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
  // Fast paths: bulk float<->half conversion helpers.
  if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
    device::FloatToHalf(dst, args.data, data_size);
    return true;
  } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
    device::HalfToFloat(dst, args.data, data_size);
    return true;
  }
  auto iter = cast_kernel_map.find(mode);
  if (iter != cast_kernel_map.end()) {
    iter->second(args, dst, data_size);
    return true;
  } else {
    MS_LOG(ERROR) << "Unsupported datatype trans";
    return false;
  }
}
  207. namespace {
  208. bool HasShapeDynamic(const std::vector<int64_t> &shape_list) {
  209. return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t shape) { return shape == Shape::SHP_ANY; });
  210. }
  211. template <typename T>
  212. bool CheckDims(const std::vector<T> &shape) {
  213. if (shape.size() != kNchwDims) {
  214. MS_LOG(ERROR) << "Host shape dims should be 4";
  215. return false;
  216. }
  217. return true;
  218. }
  219. std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  220. if (!CheckDims(shape)) {
  221. MS_LOG(EXCEPTION) << "Check dims failed.";
  222. }
  223. return shape;
  224. }
  225. std::vector<int64_t> NchwDeviceDynamicShape(const std::vector<int64_t> &shape) {
  226. if (!CheckDims(shape)) {
  227. MS_LOG(EXCEPTION) << "Check dims failed.";
  228. }
  229. return shape;
  230. }
  231. std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  232. if (!CheckDims(shape)) {
  233. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  234. }
  235. std::vector<size_t> device_shape;
  236. device_shape.push_back(shape[kN]);
  237. device_shape.push_back(shape[kH]);
  238. device_shape.push_back(shape[kW]);
  239. device_shape.push_back(shape[kC]);
  240. return device_shape;
  241. }
  242. std::vector<int64_t> NhwcDeviceDynamicShape(const std::vector<int64_t> &shape) {
  243. if (!CheckDims(shape)) {
  244. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  245. }
  246. std::vector<int64_t> device_shape;
  247. device_shape.push_back(shape[kN]);
  248. device_shape.push_back(shape[kH]);
  249. device_shape.push_back(shape[kW]);
  250. device_shape.push_back(shape[kC]);
  251. return device_shape;
  252. }
  253. std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  254. if (!CheckDims(shape)) {
  255. MS_LOG(EXCEPTION) << "Check dims failed.";
  256. }
  257. std::vector<size_t> device_shape;
  258. device_shape.push_back(shape[kH]);
  259. device_shape.push_back(shape[kW]);
  260. device_shape.push_back(shape[kC]);
  261. device_shape.push_back(shape[kN]);
  262. return device_shape;
  263. }
  264. std::vector<int64_t> HwchDeviceDynamicShape(const std::vector<int64_t> &shape) {
  265. if (!CheckDims(shape)) {
  266. MS_LOG(EXCEPTION) << "Check dims failed.";
  267. }
  268. std::vector<int64_t> device_shape;
  269. device_shape.push_back(shape[kH]);
  270. device_shape.push_back(shape[kW]);
  271. device_shape.push_back(shape[kC]);
  272. device_shape.push_back(shape[kN]);
  273. return device_shape;
  274. }
  275. std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
  276. if (!CheckDims(shape)) {
  277. MS_LOG(EXCEPTION) << "Check dims failed.";
  278. }
  279. std::vector<size_t> device_shape;
  280. const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  281. const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  282. device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
  283. device_shape.push_back(cout16 / kCubeSize);
  284. device_shape.push_back(kCubeSize);
  285. device_shape.push_back(kCubeSize);
  286. return device_shape;
  287. }
// Dynamic-shape variant of FracZDeviceShape (NCHW -> FRACTAL_Z).
// A device dim becomes SHP_ANY whenever any host dim it depends on is dynamic.
std::vector<int64_t> FracZDeviceDynamicShape(const std::vector<int64_t> &shape) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  auto tmp = SizeToLong(kCubeSize);
  std::vector<int64_t> device_shape;
  // First device dim H*W*C1 depends on C, H and W together.
  if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t cin16 = ((shape[kC] + tmp - 1) / tmp) * tmp;  // C rounded up to cube size
    device_shape.push_back(shape[kH] * shape[kW] * cin16 / tmp);
  }
  // Second device dim N1 depends on N only.
  if (shape[kN] == Shape::SHP_ANY) {
    device_shape.push_back(Shape::SHP_ANY);
  } else {
    const int64_t cout16 = ((shape[kN] + tmp - 1) / tmp) * tmp;  // N rounded up to cube size
    device_shape.push_back(cout16 / tmp);
  }
  device_shape.push_back(tmp);
  device_shape.push_back(tmp);
  return device_shape;
}
  310. std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  311. if (!CheckDims(shape)) {
  312. MS_LOG(EXCEPTION) << "Check dims failed.";
  313. }
  314. std::vector<size_t> device_shape;
  315. const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
  316. const size_t C0 = kCubeSize;
  317. device_shape.push_back(shape[kN]);
  318. device_shape.push_back(C1);
  319. device_shape.push_back(shape[kH]);
  320. device_shape.push_back(shape[kW]);
  321. device_shape.push_back(C0);
  322. return device_shape;
  323. }
  324. std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
  325. if (!CheckDims(shape)) {
  326. MS_LOG(EXCEPTION) << "Check dims failed.";
  327. }
  328. std::vector<int64_t> device_shape;
  329. auto tmp = SizeToLong(kCubeSize);
  330. const int64_t C1 = (shape[kC] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[kC] + tmp - 1) / tmp;
  331. const int64_t C0 = tmp;
  332. device_shape.push_back(shape[kN]);
  333. device_shape.push_back(C1);
  334. device_shape.push_back(shape[kH]);
  335. device_shape.push_back(shape[kW]);
  336. device_shape.push_back(C0);
  337. return device_shape;
  338. }
  339. std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  340. // NCDHW
  341. if (shape.size() != kNcdhw) {
  342. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  343. }
  344. std::vector<size_t> device_shape;
  345. const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  346. const size_t C0 = kCubeSize;
  347. device_shape.push_back(shape[N_ncdhw]);
  348. device_shape.push_back(shape[D_ncdhw]);
  349. device_shape.push_back(C1);
  350. device_shape.push_back(shape[H_ncdhw]);
  351. device_shape.push_back(shape[W_ncdhw]);
  352. device_shape.push_back(C0);
  353. return device_shape;
  354. }
  355. std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
  356. // NCDHW
  357. if (shape.size() != kNcdhw) {
  358. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  359. }
  360. auto tmp = SizeToLong(kCubeSize);
  361. std::vector<int64_t> device_shape;
  362. const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + tmp - 1) / tmp;
  363. const int64_t C0 = tmp;
  364. device_shape.push_back(shape[N_ncdhw]);
  365. device_shape.push_back(shape[D_ncdhw]);
  366. device_shape.push_back(C1);
  367. device_shape.push_back(shape[H_ncdhw]);
  368. device_shape.push_back(shape[W_ncdhw]);
  369. device_shape.push_back(C0);
  370. return device_shape;
  371. }
  372. std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
  373. // NCDHW -> Frac_Z_3D
  374. if (shape.size() != kNcdhw) {
  375. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  376. }
  377. std::vector<size_t> device_shape;
  378. const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  379. const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
  380. device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
  381. device_shape.push_back(N1);
  382. device_shape.push_back(kCubeSize);
  383. device_shape.push_back(kCubeSize);
  384. return device_shape;
  385. }
  386. std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
  387. // NCDHW -> Frac_Z_3D
  388. if (shape.size() != kNcdhw) {
  389. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  390. }
  391. std::vector<int64_t> device_shape;
  392. auto tmp = SizeToLong(kCubeSize);
  393. if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
  394. device_shape.push_back(Shape::SHP_ANY);
  395. } else {
  396. const int64_t C1 = (shape[1] + tmp - 1) / tmp;
  397. device_shape.push_back(shape[D_ncdhw] * C1 * shape[H_ncdhw] * shape[W_ncdhw]);
  398. }
  399. const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + tmp - 1) / tmp;
  400. device_shape.push_back(N1);
  401. device_shape.push_back(tmp);
  402. device_shape.push_back(tmp);
  403. return device_shape;
  404. }
  405. std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
  406. if (!CheckDims(shape)) {
  407. MS_LOG(EXCEPTION) << "Check dims failed.";
  408. }
  409. std::vector<size_t> device_shape;
  410. device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
  411. device_shape.push_back(shape[kH]);
  412. device_shape.push_back(shape[kW]);
  413. device_shape.push_back(shape[kN]);
  414. device_shape.push_back(kCubeSize);
  415. device_shape.push_back(kCubeSize);
  416. return device_shape;
  417. }
  418. std::vector<int64_t> C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
  419. if (!CheckDims(shape)) {
  420. MS_LOG(EXCEPTION) << "Check dims failed.";
  421. }
  422. std::vector<int64_t> device_shape;
  423. auto tmp = SizeToLong(kCubeSize);
  424. shape[kC] == Shape::SHP_ANY ? device_shape.push_back(Shape::SHP_ANY)
  425. : device_shape.push_back((shape[kC] - 1) / tmp + 1);
  426. device_shape.push_back(shape[kH]);
  427. device_shape.push_back(shape[kW]);
  428. device_shape.push_back(shape[kN]);
  429. device_shape.push_back(tmp);
  430. device_shape.push_back(tmp);
  431. return device_shape;
  432. }
  433. std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
  434. if (!CheckDims(shape)) {
  435. MS_LOG(EXCEPTION) << "Check dims failed.";
  436. }
  437. std::vector<size_t> device_shape;
  438. const size_t c0 = 4;
  439. auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
  440. auto no = DivCeil(shape.at(kN), kCubeSize);
  441. device_shape.push_back(first_dim);
  442. device_shape.push_back(no);
  443. device_shape.push_back(kCubeSize);
  444. device_shape.push_back(kCubeSize);
  445. return device_shape;
  446. }
  447. std::vector<int64_t> FracZc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
  448. if (!CheckDims(shape)) {
  449. MS_LOG(EXCEPTION) << "Check dims failed.";
  450. }
  451. std::vector<int64_t> device_shape;
  452. const int64_t c0 = 4;
  453. int64_t first_dim;
  454. if (HasShapeDynamic({shape[kH], shape[kW]})) {
  455. first_dim = Shape::SHP_ANY;
  456. } else {
  457. first_dim = DivCeil(c0 * shape[kH] * shape[kW], SizeToLong(kCubeSize));
  458. }
  459. auto shape_kN = shape.at(kN);
  460. int64_t no = (shape_kN == Shape::SHP_ANY) ? Shape::SHP_ANY : DivCeil(shape.at(kN), SizeToLong(kCubeSize));
  461. device_shape.push_back(first_dim);
  462. device_shape.push_back(no);
  463. device_shape.push_back(SizeToLong(kCubeSize));
  464. device_shape.push_back(SizeToLong(kCubeSize));
  465. return device_shape;
  466. }
  467. std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
  468. if (!CheckDims(shape)) {
  469. MS_LOG(EXCEPTION) << "Check dims failed.";
  470. }
  471. std::vector<size_t> device_shape;
  472. const size_t C1 = 1;
  473. const size_t C0 = 4;
  474. device_shape.push_back(shape[kN]);
  475. device_shape.push_back(C1);
  476. device_shape.push_back(shape[kH]);
  477. device_shape.push_back(shape[kW]);
  478. device_shape.push_back(C0);
  479. return device_shape;
  480. }
  481. std::vector<int64_t> Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
  482. if (!CheckDims(shape)) {
  483. MS_LOG(EXCEPTION) << "Check dims failed.";
  484. }
  485. std::vector<int64_t> device_shape;
  486. const int64_t C1 = 1;
  487. const int64_t C0 = 4;
  488. device_shape.push_back(shape[kN]);
  489. device_shape.push_back(C1);
  490. device_shape.push_back(shape[kH]);
  491. device_shape.push_back(shape[kW]);
  492. device_shape.push_back(C0);
  493. return device_shape;
  494. }
  495. std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
  496. if (shape.size() < kNcdhw) {
  497. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  498. }
  499. return shape;
  500. }
  501. std::vector<int64_t> NcdhwDeviceDynamicShape(const std::vector<int64_t> &shape) {
  502. if (shape.size() < kNcdhw) {
  503. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  504. }
  505. return shape;
  506. }
  507. // change channel-first shape to channel-last shape.
  508. // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
  509. std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
  510. auto dim = shape.size();
  511. std::vector<size_t> axis;
  512. axis.resize(dim);
  513. int step_value = 2;
  514. std::iota(axis.begin() + 1, axis.end(), step_value);
  515. axis[dim - 1] = 1;
  516. std::vector<size_t> device_shape;
  517. (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
  518. [&shape](size_t n) { return shape[n]; });
  519. return device_shape;
  520. }
  521. // change channel-first shape to channel-last shape.
  522. // eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
  523. std::vector<int64_t> ChannelLastDeviceDynamicShape(const std::vector<int64_t> &shape) {
  524. auto dim = shape.size();
  525. std::vector<int64_t> axis;
  526. axis.resize(dim);
  527. int step_value = 2;
  528. std::iota(axis.begin() + 1, axis.end(), step_value);
  529. axis[dim - 1] = 1;
  530. std::vector<int64_t> device_shape;
  531. (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
  532. [&shape](size_t n) { return shape[n]; });
  533. return device_shape;
  534. }
  535. std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
  536. if (!CheckDims(shape)) {
  537. MS_LOG(EXCEPTION) << "Check dims failed.";
  538. }
  539. if (groups <= 0) {
  540. MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  541. }
  542. size_t group_size = LongToSize(groups);
  543. size_t cin_ori = shape[kC];
  544. size_t cout_ori = shape[kN] / group_size;
  545. size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  546. size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
  547. size_t c1_dim = cin_opt / kCubeSize;
  548. size_t g_dim = DivCeil(group_size, e_mult);
  549. size_t n1 = DivCeil(cout_ori * e_mult, kCubeSize);
  550. std::vector<size_t> device_shape;
  551. device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  552. device_shape.push_back(n1);
  553. device_shape.push_back(kNiSize);
  554. device_shape.push_back(kCubeSize);
  555. return device_shape;
  556. }
  557. std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shape, const int64_t groups = 1) {
  558. if (!CheckDims(shape)) {
  559. MS_LOG(EXCEPTION) << "Check dims failed.";
  560. }
  561. int64_t c1_dim = Shape::SHP_ANY;
  562. int64_t g_dim = Shape::SHP_ANY;
  563. int64_t n1 = Shape::SHP_ANY;
  564. if (groups <= 0) {
  565. MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  566. }
  567. auto tmp = SizeToLong(kCubeSize);
  568. if (!HasShapeDynamic({shape[kC], shape[kN]})) {
  569. size_t group_size = LongToSize(groups);
  570. size_t cin_ori_tmp = LongToSize(shape[kC]);
  571. size_t cout_ori_tmp = LongToSize(shape[kN]) / group_size;
  572. size_t e_mult =
  573. std::min(Lcm(Lcm(cin_ori_tmp, kCubeSize) / cin_ori_tmp, Lcm(cout_ori_tmp, kCubeSize) / cout_ori_tmp), group_size);
  574. int64_t cin_opt = SizeToLong(DivCeil(e_mult * cin_ori_tmp, kCubeSize) * kCubeSize);
  575. c1_dim = cin_opt / tmp;
  576. g_dim = SizeToLong(DivCeil(group_size, e_mult));
  577. n1 = SizeToLong(DivCeil(cout_ori_tmp * e_mult, kCubeSize));
  578. }
  579. std::vector<int64_t> device_shape;
  580. if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
  581. device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  582. } else {
  583. device_shape.push_back(Shape::SHP_ANY);
  584. }
  585. device_shape.push_back(n1);
  586. device_shape.push_back(SizeToLong(kNiSize));
  587. device_shape.push_back(tmp);
  588. return device_shape;
  589. }
  590. std::vector<size_t> FracNZDeviceShape(const std::vector<size_t> &shape) {
  591. if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
  592. // For [1] and [1024] shape we can trait it as NZ shape
  593. return shape;
  594. }
  595. std::vector<size_t> device_shape;
  596. if (shape.size() < kShape2dDims) {
  597. MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
  598. } else {
  599. const auto remove_dim = 2;
  600. (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
  601. }
  602. auto h1 = (shape[shape.size() - kDim2] - 1) / kCubeSize + 1;
  603. auto w1 = (shape[shape.size() - kDim1] - 1) / kCubeSize + 1;
  604. device_shape.push_back(w1);
  605. device_shape.push_back(h1);
  606. device_shape.push_back(kCubeSize);
  607. device_shape.push_back(kCubeSize);
  608. return device_shape;
  609. }
  610. std::vector<int64_t> FracNZDeviceDynamicShape(const std::vector<int64_t> &shape) {
  611. std::vector<int64_t> device_shape;
  612. if (shape.size() == 1 && (shape[0] == 1 || shape[0] % SizeToLong(kCubeSize) == 0)) {
  613. // For [1] and [1024] shape we can trait it as NZ shape
  614. return shape;
  615. }
  616. if (shape.size() < kShape2dDims) {
  617. MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
  618. } else {
  619. (void)std::copy(shape.begin(), shape.end() - kDim2, std::back_inserter(device_shape));
  620. }
  621. int64_t h_shape = shape[shape.size() - kDim2];
  622. int64_t w_shape = shape[shape.size() - kDim1];
  623. int64_t h1 = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (h_shape - 1) / SizeToLong(kCubeSize) + 1;
  624. int64_t w1 = (w_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (w_shape - 1) / SizeToLong(kCubeSize) + 1;
  625. device_shape.push_back(w1);
  626. device_shape.push_back(h1);
  627. device_shape.push_back(kCubeSize);
  628. device_shape.push_back(kCubeSize);
  629. return device_shape;
  630. }
  631. std::vector<size_t> FracNZLSTMDeviceShape(const std::vector<size_t> &shape) {
  632. const size_t c0 = 4;
  633. const size_t h = shape.at(kN) / c0;
  634. const size_t i = shape.at(kC) - h;
  635. const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
  636. const size_t second = c0 * DivCeil(h, kCubeSize);
  637. std::vector<size_t> device_shape;
  638. device_shape.push_back(first);
  639. device_shape.push_back(second);
  640. device_shape.push_back(kCubeSize);
  641. device_shape.push_back(kCubeSize);
  642. return device_shape;
  643. }
  644. std::vector<int64_t> FracNZLSTMDeviceDynamicShape(const std::vector<int64_t> &shape) {
  645. std::vector<int64_t> device_shape;
  646. const int64_t c0 = 4;
  647. const int64_t h_shape = shape.at(kN);
  648. const int64_t i_shape = shape.at(kC);
  649. const int64_t h = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : h_shape / c0;
  650. int64_t first = Shape::SHP_ANY;
  651. if (h_shape != Shape::SHP_ANY && i_shape != Shape::SHP_ANY) {
  652. int64_t i = i_shape - h;
  653. first = DivCeil(i, SizeToLong(kCubeSize)) + DivCeil(h, SizeToLong(kCubeSize));
  654. }
  655. const int64_t second = (h == Shape::SHP_ANY) ? Shape::SHP_ANY : c0 * DivCeil(h, SizeToLong(kCubeSize));
  656. device_shape.push_back(first);
  657. device_shape.push_back(second);
  658. device_shape.push_back(kCubeSize);
  659. device_shape.push_back(kCubeSize);
  660. return device_shape;
  661. }
  662. std::vector<size_t> FracZNRNNDeviceShape(const std::vector<size_t> &shape,
  663. const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
  664. if (shape.size() < kShape2dDims) {
  665. MS_LOG(EXCEPTION) << "Format FRACTAL_ZN_RNN don't support shape with " << shape.size() << " dims";
  666. }
  667. size_t input_size = LongToSize(input_hidden_size[0]);
  668. size_t hidden_size = LongToSize(input_hidden_size[1]);
  669. auto dim_last1 = shape[shape.size() - kDim1];
  670. auto dim_last2 = shape[shape.size() - kDim2];
  671. if (hidden_size == 0) {
  672. MS_LOG(EXCEPTION) << "Hidden_size should not be 0.";
  673. }
  674. // cppcheck-suppress *
  675. if (dim_last1 % hidden_size != 0) {
  676. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  677. }
  678. size_t n_num = dim_last1 / hidden_size;
  679. const size_t NUM16 = 16;
  680. const size_t C0 = kCubeSize;
  681. std::vector<size_t> device_shape = shape;
  682. if (dim_last2 == input_size || dim_last2 == hidden_size) {
  683. device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  684. } else if (dim_last2 == input_size + hidden_size) {
  685. device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  686. } else {
  687. MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
  688. "`hidden_size` or `input_size + hidden_size`, but got second-last dim value: "
  689. << dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
  690. }
  691. device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0);
  692. device_shape.push_back(NUM16);
  693. device_shape.push_back(C0);
  694. return device_shape;
  695. }
  696. std::vector<int64_t> FracZNRNNDeviceDynamicShape(const std::vector<int64_t> &shape,
  697. const std::vector<int64_t> &input_hidden_size = {kAlign16, kAlign16}) {
  698. if (shape.size() < kShape2dDims) {
  699. MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
  700. }
  701. int64_t input_size = input_hidden_size[0];
  702. int64_t hidden_size = input_hidden_size[1];
  703. auto dim_last1 = shape[shape.size() - kDim1];
  704. auto dim_last2 = shape[shape.size() - kDim2];
  705. const int64_t NUM16 = 16;
  706. const int64_t C0 = SizeToLong(kCubeSize);
  707. std::vector<int64_t> device_shape = shape;
  708. if (dim_last2 == Shape::SHP_ANY) {
  709. device_shape[shape.size() - kDim2] = Shape::SHP_ANY;
  710. } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
  711. device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  712. } else if (dim_last2 == input_size + hidden_size) {
  713. device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  714. } else {
  715. MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid. Should be equal to `input_size` or "
  716. "`hidden_size` or `input_size + hidden_size` or `-1`, but got second-last dim value: "
  717. << dim_last2 << " input_size: " << input_size << " hidden_size: " << hidden_size;
  718. }
  719. if (dim_last1 == Shape::SHP_ANY) {
  720. device_shape[shape.size() - kDim1] = Shape::SHP_ANY;
  721. } else {
  722. if (dim_last1 % hidden_size != 0) {
  723. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  724. }
  725. int64_t n_num = shape[shape.size() - 1] / hidden_size;
  726. device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
  727. }
  728. device_shape.push_back(NUM16);
  729. device_shape.push_back(C0);
  730. return device_shape;
  731. }
  732. std::vector<size_t> NDRNNBiasDeviceShape(const std::vector<size_t> &shape, const int64_t hidden_size = 16) {
  733. if (shape.empty()) {
  734. MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
  735. }
  736. if (hidden_size <= 0) {
  737. MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
  738. }
  739. size_t hid_size = LongToSize(hidden_size);
  740. // cppcheck-suppress *
  741. if (shape[shape.size() - 1] % hid_size != 0) {
  742. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hid_size;
  743. }
  744. size_t n_num = shape[shape.size() - 1] / hid_size;
  745. const size_t C0 = kCubeSize;
  746. std::vector<size_t> device_shape = shape;
  747. device_shape[shape.size() - 1] = n_num * DivCeil(hid_size, C0) * C0;
  748. return device_shape;
  749. }
  750. std::vector<int64_t> NDRNNBiasDeviceDynamicShape(const std::vector<int64_t> &shape, const int64_t hidden_size = 16) {
  751. if (shape.empty()) {
  752. MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
  753. }
  754. const int64_t C0 = SizeToLong(kCubeSize);
  755. std::vector<int64_t> device_shape = shape;
  756. // cppcheck-suppress *
  757. auto dim_last1 = shape[shape.size() - 1];
  758. if (dim_last1 == Shape::SHP_ANY) {
  759. device_shape[shape.size() - 1] = Shape::SHP_ANY;
  760. } else {
  761. if (hidden_size <= 0) {
  762. MS_LOG(EXCEPTION) << "Hidden_size should be greater than 0, but got " << hidden_size;
  763. }
  764. // cppcheck-suppress *
  765. if (dim_last1 % hidden_size != 0) {
  766. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  767. }
  768. int64_t n_num = shape[shape.size() - 1] / hidden_size;
  769. device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
  770. }
  771. return device_shape;
  772. }
  773. } // namespace
  774. int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
  775. if (node == nullptr) {
  776. return 1;
  777. }
  778. if (node->isa<CNode>()) {
  779. auto cnode = node->cast<CNodePtr>();
  780. if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
  781. auto node_name = AnfAlgo::GetCNodeName(cnode);
  782. if (node_name == kAllReduceOpName || node_name == kBroadcastOpName) {
  783. // if index not exists in fracz_group_idx, return default value 1
  784. auto fz_group_idx = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrFracZGroupIdx);
  785. int64_t out_index = SizeToLong(index);
  786. auto fz_iter = std::find(std::begin(fz_group_idx), std::end(fz_group_idx), out_index);
  787. if (fz_iter == std::end(fz_group_idx)) {
  788. return 1;
  789. }
  790. }
  791. return AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFracZGroup);
  792. }
  793. } else if (node->isa<Parameter>()) {
  794. auto param = node->cast<ParameterPtr>();
  795. MS_EXCEPTION_IF_NULL(param);
  796. return param->fracz_group();
  797. }
  798. return 1;
  799. }
  800. std::vector<int64_t> GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
  801. MS_EXCEPTION_IF_NULL(node);
  802. std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
  803. if (!node->isa<CNode>() && !node->isa<Parameter>()) {
  804. return input_hidden_size;
  805. }
  806. if (node->isa<Parameter>()) {
  807. auto param = node->cast<ParameterPtr>();
  808. input_hidden_size[0] = param->input_size();
  809. input_hidden_size[1] = param->hidden_size();
  810. } else {
  811. CNodePtr cnode = node->cast<CNodePtr>();
  812. if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
  813. !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
  814. MS_LOG(EXCEPTION)
  815. << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
  816. << node->DebugString();
  817. }
  818. input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
  819. input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
  820. }
  821. return input_hidden_size;
  822. }
  823. bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  824. if (shape_size == 0) {
  825. return false;
  826. }
  827. if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW ||
  828. kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) {
  829. return false;
  830. } else if (shape_size < kNchwDims) {
  831. return true;
  832. }
  833. return false;
  834. }
  835. ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  836. MS_EXCEPTION_IF_NULL(node);
  837. ShapeVector shape;
  838. std::vector<size_t> host_shape;
  839. if (node->isa<ValueNode>()) {
  840. auto value_node = node->cast<ValueNodePtr>();
  841. MS_EXCEPTION_IF_NULL(value_node);
  842. auto node_value = value_node->value();
  843. MS_EXCEPTION_IF_NULL(node_value);
  844. auto tensor = node_value->cast<tensor::TensorPtr>();
  845. if (tensor == nullptr) {
  846. MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
  847. }
  848. auto shape_temp = tensor->shape();
  849. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
  850. if (host_shape.empty()) {
  851. host_shape.push_back(1);
  852. }
  853. } else {
  854. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  855. }
  856. auto format = AnfAlgo::GetOutputFormat(node, index);
  857. if (trans::IsNeedPadding(format, host_shape.size())) {
  858. host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
  859. }
  860. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
  861. return shape;
  862. }
  863. void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
  864. MS_EXCEPTION_IF_NULL(reshape_type_vec);
  865. if (reshape_type_str.empty()) {
  866. MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
  867. return;
  868. }
  869. for (const auto &c : reshape_type_str) {
  870. switch (c) {
  871. case 'N':
  872. reshape_type_vec->push_back(N);
  873. break;
  874. case 'C':
  875. reshape_type_vec->push_back(C);
  876. break;
  877. case 'H':
  878. reshape_type_vec->push_back(H);
  879. break;
  880. case 'W':
  881. reshape_type_vec->push_back(W);
  882. break;
  883. default:
  884. MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
  885. }
  886. }
  887. }
  888. void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
  889. MS_EXCEPTION_IF_NULL(reshape_type_vec);
  890. if (reshape_type_str.empty()) {
  891. MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
  892. return;
  893. }
  894. for (const auto &c : reshape_type_str) {
  895. switch (c) {
  896. case 'N':
  897. reshape_type_vec->push_back(N_ncdhw);
  898. break;
  899. case 'C':
  900. reshape_type_vec->push_back(C_ncdhw);
  901. break;
  902. case 'D':
  903. reshape_type_vec->push_back(D_ncdhw);
  904. break;
  905. case 'H':
  906. reshape_type_vec->push_back(H_ncdhw);
  907. break;
  908. case 'W':
  909. reshape_type_vec->push_back(W_ncdhw);
  910. break;
  911. default:
  912. MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
  913. }
  914. }
  915. }
  916. std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
  917. const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
  918. using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  919. const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
  920. {kOpFormat_NHWC, NhwcDeviceShape},
  921. {kOpFormat_HWCN, HwchDeviceShape},
  922. {kOpFormat_FRAC_Z, FracZDeviceShape},
  923. {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
  924. {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
  925. {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
  926. {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
  927. {kOpFormat_NCDHW, NcdhwDeviceShape},
  928. {kOpFormat_ChannelLast, ChannelLastDeviceShape},
  929. {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
  930. {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape},
  931. {kOpFormat_FRAC_NZ, FracNZDeviceShape},
  932. {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceShape}};
  933. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
  934. return shape;
  935. }
  936. if (groups > 1 && format == kOpFormat_FRAC_Z) {
  937. return FracZDeviceShapeWithGroups(shape, groups);
  938. }
  939. if (format == kOpFormat_FRACTAL_ZN_RNN) {
  940. return FracZNRNNDeviceShape(shape, input_hidden_size);
  941. }
  942. if (format == kOpFormat_ND_RNN_BIAS) {
  943. return NDRNNBiasDeviceShape(shape, input_hidden_size[1]);
  944. }
  945. auto temp_shape = shape;
  946. if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
  947. shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
  948. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  949. temp_shape = PaddingShapeTo4dDefault(shape);
  950. }
  951. if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
  952. temp_shape = PaddingShapeTo5dDefault(shape);
  953. }
  954. auto iter = device_shape_map.find(format);
  955. if (iter == device_shape_map.end()) {
  956. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  957. }
  958. return iter->second(temp_shape);
  959. }
  960. std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
  961. const int64_t groups, const std::vector<int64_t> &input_hidden_size) {
  962. using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
  963. const std::map<std::string, DeviceShapeTransfer> device_shape_map{
  964. {kOpFormat_NCHW, NchwDeviceDynamicShape},
  965. {kOpFormat_NHWC, NhwcDeviceDynamicShape},
  966. {kOpFormat_HWCN, HwchDeviceDynamicShape},
  967. {kOpFormat_FRAC_Z, FracZDeviceDynamicShape},
  968. {kOpFormat_NC1HWC0, Nc1hwc0DeviceDynamicShape},
  969. {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceDynamicShape},
  970. {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceDynamicShape},
  971. {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceDynamicShape},
  972. {kOpFormat_NCDHW, NcdhwDeviceDynamicShape},
  973. {kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape},
  974. {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape},
  975. {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape},
  976. {kOpFormat_FRAC_NZ, FracNZDeviceDynamicShape},
  977. {kOpFormat_FRACTAL_ZN_LSTM, FracNZLSTMDeviceDynamicShape}};
  978. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
  979. return shape;
  980. }
  981. if (groups > 1 && format == kOpFormat_FRAC_Z) {
  982. return FracZDeviceShapeWithGroups(shape, groups);
  983. }
  984. if (format == kOpFormat_FRACTAL_ZN_RNN) {
  985. return FracZNRNNDeviceDynamicShape(shape, input_hidden_size);
  986. }
  987. if (format == kOpFormat_ND_RNN_BIAS) {
  988. return NDRNNBiasDeviceDynamicShape(shape, input_hidden_size[1]);
  989. }
  990. auto temp_shape = shape;
  991. if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
  992. shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
  993. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  994. temp_shape = PaddingShapeTo4dDefault(shape);
  995. }
  996. if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
  997. temp_shape = PaddingShapeTo5dDefault(shape);
  998. }
  999. auto iter = device_shape_map.find(format);
  1000. if (iter == device_shape_map.end()) {
  1001. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  1002. }
  1003. return iter->second(temp_shape);
  1004. }
  1005. bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  1006. if (args.host_shape.size() != kNchwDims) {
  1007. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  1008. return false;
  1009. }
  1010. MS_EXCEPTION_IF_NULL(size);
  1011. MS_EXCEPTION_IF_NULL(total_size);
  1012. *size = abstract::TypeIdSize(args.src_data_type);
  1013. if (*size < 1) {
  1014. MS_LOG(ERROR) << "Illegal dtype.";
  1015. return false;
  1016. }
  1017. *total_size = abstract::ShapeSize(args.device_shape) * (*size);
  1018. if (*total_size != args.device_size) {
  1019. MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
  1020. return false;
  1021. }
  1022. return true;
  1023. }
  1024. bool TransDataType(const TypeIdArgs &args, void *result) {
  1025. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  1026. << TypeIdLabel(args.device_data_type);
  1027. MS_EXCEPTION_IF_NULL(result);
  1028. std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  1029. auto iter = mode_map.find(type_info);
  1030. if (iter == mode_map.end()) {
  1031. MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
  1032. << ", dst_type:" << TypeIdLabel(args.device_data_type);
  1033. return false;
  1034. }
  1035. auto trans_mode = iter->second;
  1036. if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
  1037. MS_LOG(ERROR) << "Failed to trans datatype..";
  1038. return false;
  1039. }
  1040. return true;
  1041. }
  1042. bool TransFormat(const FormatArgs &args, void *result, int64_t groups) {
  1043. MS_LOG(DEBUG) << "Start trans format.";
  1044. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  1045. MS_LOG(ERROR) << "Invalid datatype..";
  1046. return false;
  1047. }
  1048. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  1049. return NchwTo4D(args, result);
  1050. }
  1051. if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
  1052. return NchwToFracZWithGroups(args, result, groups);
  1053. }
  1054. auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
  1055. if (iter == kTransFormatMapOfHostToDevice.end()) {
  1056. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  1057. }
  1058. return iter->second(args, result);
  1059. }
  1060. bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
  1061. int64_t groups = 1;
  1062. if (args.device_format == kOpFormat_FRAC_Z) {
  1063. groups = GetAttrGroups(node, index);
  1064. }
  1065. return TransFormat(args, result, groups);
  1066. }
  1067. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
  1068. const std::map<std::string, FormatTransfer> format_trans_map{
  1069. {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
  1070. {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
  1071. {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
  1072. {kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
  1073. MS_LOG(DEBUG) << "Start trans format.";
  1074. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  1075. MS_LOG(ERROR) << "Invalid datatype..";
  1076. return false;
  1077. }
  1078. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  1079. return ToNchw(args, result);
  1080. }
  1081. if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
  1082. return FracZToNchwWithGroups(args, result, groups);
  1083. }
  1084. auto iter = format_trans_map.find(args.device_format);
  1085. if (iter == format_trans_map.end()) {
  1086. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  1087. }
  1088. return iter->second(args, result);
  1089. }
  1090. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index) {
  1091. int64_t groups = 1;
  1092. if (args.device_format == kOpFormat_FRAC_Z) {
  1093. groups = GetAttrGroups(node, index);
  1094. }
  1095. return TransFormatFromDeviceToHost(args, result, groups);
  1096. }
// Copies NCHW host data into a plain 4-D device layout (NHWC or HWCN).
// Returns false if CheckArgs rejects the shapes/sizes; true on success.
bool NchwTo4D(const FormatArgs &args, void *result) {
  // trans nchw to 4d
  MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;        // element size in bytes, filled by CheckArgs
  size_t total_size = 0;  // device buffer size, filled by CheckArgs
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
        for (size_t wi = 0; wi < w; wi++) {
          // Linear offset of (ni, ci, hi, wi) in the NCHW source.
          auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          size_t dst_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          // NOTE(review): any other device_format leaves dst_idx at 0, so every
          // element would land at offset 0 — presumably callers only dispatch
          // NHWC/HWCN here; confirm against the dispatch sites.
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Inverse of NchwTo4D: copies a 4-D device layout (NHWC or HWCN) back into the
// NCHW host layout. Returns false if CheckArgs rejects the shapes/sizes.
bool ToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;        // element size in bytes, filled by CheckArgs
  size_t total_size = 0;  // device buffer size, filled by CheckArgs
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  for (size_t ni = 0; ni < n; ni++) {
    for (size_t ci = 0; ci < c; ci++) {
      for (size_t hi = 0; hi < h; hi++) {
        for (size_t wi = 0; wi < w; wi++) {
          // Linear offset of (ni, ci, hi, wi) in the NCHW destination.
          auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
          size_t src_idx = 0;
          if (args.device_format == kOpFormat_NHWC) {
            src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
          } else if (args.device_format == kOpFormat_HWCN) {
            src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
          }
          // NOTE(review): any other device_format reads every element from
          // offset 0 — presumably callers only dispatch NHWC/HWCN here.
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Converts NCHW host data to the FRAC_Z fractal layout: the tensor is tiled
// into (c0 x kCubeSize) fractal matrices; positions past the real N/C extents
// are zero-padded. Returns false on invalid shape, dtype, or size mismatch.
bool NchwToFracZ(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const size_t c0 = 16;       // channels per fractal column
  auto c1 = DivCeil(c, c0);   // number of c0-sized channel groups (rounded up)
  auto hw = h * w;
  auto chw = c * hw;
  auto hwc0 = hw * c0;
  auto nchw = n * chw;        // total source element count, used as a read bound
  auto hf_cnt = DivCeil(n, kCubeSize);           // horizontal fractal count (over N)
  auto vf_cnt = c1 * hw;                         // vertical fractal count (over C1*H*W)
  auto fractal_ele_cnt = c0 * kCubeSize;         // elements per fractal matrix
  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  auto dst_size = total_ele_cnt * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size."
                  << "dst size is :" << dst_size << "device size is :" << args.device_size;
    return false;
  }
  for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
    auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
    for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
      auto gfi = vf_base_i + hfi;  // global fractal matrix index
      auto src_n_offset = hfi * chw * kCubeSize;
      // vfi % hw selects the spatial position, vfi / hw the channel group.
      auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
      for (size_t row = 0; row < c0; row++) {
        auto src_ci = vfi / hw * c0 + row;
        auto src_row_offset = src_f_offset + row * hw;
        for (size_t col = 0; col < kCubeSize; col++) {
          auto src_ni = hfi * kCubeSize + col;
          auto src_idx = src_row_offset + chw * col;
          auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
          // Zero-fill positions that fall outside the real tensor extents.
          auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Converts FRAC_Z device data back to NCHW. Iterates over the destination
// NCHW positions and gathers each element from the fractal source layout
// (device_shape is {C1, N1, N0, C0}). Returns false on invalid shape, dtype,
// or size mismatch.
bool FracZToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n0 = args.device_shape.at(1);
  auto ni = args.device_shape.at(2);
  auto c0 = args.device_shape.at(3);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Strides of the fractal source layout.
  auto nc = ni * n0;
  auto ncc0 = nc * c0;
  auto wncc0 = w * ncc0;
  auto hwncc0 = h * wncc0;
  auto hw = h * w;
  auto chw = c * hw;
  for (size_t n_idx = 0; n_idx < n; n_idx++) {
    size_t n_head_addr = n_idx * chw;
    for (size_t c_idx = 0; c_idx < c; c_idx++) {
      size_t c_head_addr = n_head_addr + c_idx * hw;
      for (size_t h_idx = 0; h_idx < h; h_idx++) {
        size_t h_head_addr = c_head_addr + h_idx * w;
        for (size_t w_idx = 0; w_idx < w; w_idx++) {
          size_t dst_idx = h_head_addr + w_idx;
          // Split the channel index into its c1 group and c0 remainder.
          size_t c1_idx = c_idx / c0;
          size_t c0_idx = c_idx % c0;
          size_t nc_idx = n_idx;
          size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Converts NCHW host data to the FRAC_Z_C04 layout (c0 = 4 variant of FRAC_Z):
// the destination is written sequentially cube by cube while the matching
// source element is gathered from NCHW; out-of-range positions are zero-padded.
bool NchwToFracZc04(const FormatArgs &args, void *result) {
  // trans nchw to FracZc04
  MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;        // element size in bytes, filled by CheckArgs
  size_t total_size = 0;  // device buffer size, filled by CheckArgs
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto cube = kCubeSize;
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const size_t c0 = 4;      // C04 variant packs 4 channels per fractal column
  auto c1 = DivCeil(c, c0);
  auto hwc0 = h * w * c0;
  auto hwc = h * w * c;
  auto nhwc = n * h * w * c;
  auto n_cnt = DivCeil(n, cube);             // cube count along N
  auto v_cnt = DivCeil(h * w * c0 * c1, cube);  // cube count along C1*H*W*C0
  size_t dst_idx = 0;  // destination is written strictly sequentially
  for (size_t vi = 0; vi < v_cnt; vi++) {
    for (size_t ni = 0; ni < n_cnt; ni++) {
      for (size_t col = 0; col < cube; col++) {
        for (size_t row = 0; row < cube; row++) {
          size_t cur_cube_n = cube * ni + col;
          size_t cur_cube_c1hwc0 = cube * vi + row;
          // Decompose the flat cube coordinates back into (g, n, c1, h, w, c0).
          auto desc_g = cur_cube_n / n;
          auto desc_n = cur_cube_n % n;
          auto desc_c1 = cur_cube_c1hwc0 / hwc0;
          auto desc_c0 = cur_cube_c1hwc0 % c0;  // hwc0 is a multiple of c0, so this is exact
          auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
          auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
          auto c_idx = desc_c1 * c0 + desc_c0;
          auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
          // Zero-fill padding introduced by rounding N and C up to cube sizes.
          auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
          dst_idx++;
        }
      }
    }
  }
  return true;
}
  1308. bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
  1309. MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  1310. return NchwToNc1hwc0(args, result);
  1311. }
  1312. bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
  1313. MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  1314. return Nc1hwc0ToNchw(args, result);
  1315. }
  1316. bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  1317. MS_EXCEPTION_IF_NULL(hw_shape);
  1318. if (host_shape.empty()) {
  1319. MS_LOG(ERROR) << "Size of vector is 0.";
  1320. return false;
  1321. }
  1322. switch (host_shape.size()) {
  1323. case 1:
  1324. hw_shape->push_back(1);
  1325. hw_shape->push_back(1);
  1326. hw_shape->push_back(host_shape[0]);
  1327. return true;
  1328. default:
  1329. auto size = host_shape.size();
  1330. if (size < kDim2) {
  1331. MS_LOG(ERROR) << "Illegal size.";
  1332. return false;
  1333. }
  1334. size_t times = 1;
  1335. for (size_t i = 0; i != size - kDim2; i++) {
  1336. times *= host_shape[i];
  1337. }
  1338. hw_shape->push_back(times);
  1339. hw_shape->push_back(host_shape[size - kDim2]);
  1340. hw_shape->push_back(host_shape[size - kDim1]);
  1341. return true;
  1342. }
  1343. }
// Converts host data to the FRAC_NZ fractal layout. The host shape is first
// normalized to {times, H, W}; the device shape's trailing four axes are
// {W1, H1, H0, W0}. Full W0-wide column tiles are copied first, then the
// remainder columns (w % w0) are handled separately.
bool NchwToFracNz(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;  // {times, h, w} view of the host shape
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }
  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // product of all leading (batch) axes
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;
  auto shape_size = args.device_shape.size();
  // Trailing device axes: {w1, h1, h0, w0}.
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // count of complete w0-wide tiles
  for (size_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the complete w0-wide tiles of this row.
      for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (size_t i = 0; i < w0; ++i) {
          size_t src_idx = src_h_head + w1_idx * w0 + i;
          size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Copy the tail columns that do not fill a whole w0 tile.
      auto w1_head = num_w1 * w0;
      for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        size_t src_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
// Inverse of NchwToFracNz: converts FRAC_NZ device data back to the host
// layout. Identical traversal to NchwToFracNz with source and destination
// index roles swapped.
bool FracNzToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;  // {times, h, w} view of the host shape
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }
  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // product of all leading (batch) axes
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;
  auto shape_size = args.device_shape.size();
  // Trailing device axes: {w1, h1, h0, w0}.
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // count of complete w0-wide tiles
  for (size_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Gather the complete w0-wide tiles of this row.
      for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (size_t i = 0; i < w0; ++i) {
          size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
          size_t dst_idx = src_h_head + w1_idx * w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Gather the tail columns that do not fill a whole w0 tile.
      auto w1_head = num_w1 * w0;
      for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        size_t dst_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
  1460. bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  1461. MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  1462. MS_EXCEPTION_IF_NULL(result);
  1463. if (args.host_shape.size() != kNchwDims) {
  1464. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  1465. return false;
  1466. }
  1467. auto size = abstract::TypeIdSize(args.src_data_type);
  1468. if (size < 1) {
  1469. MS_LOG(ERROR) << "Illegal dtype.";
  1470. return false;
  1471. }
  1472. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1473. if (total_size != args.device_size) {
  1474. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1475. return false;
  1476. }
  1477. auto n = args.host_shape[kN];
  1478. auto c = args.host_shape[kC];
  1479. auto h = args.host_shape[kH];
  1480. auto w = args.host_shape[kW];
  1481. size_t c0 = kCubeSize;
  1482. if (args.device_format == kOpFormat_NC1HWC0_C04) {
  1483. c0 = kCubeSize_C04;
  1484. }
  1485. auto c1 = DivCeil(c, c0);
  1486. auto hw = h * w;
  1487. auto chw = c * hw;
  1488. auto c1hwc0 = c1 * hw * c0;
  1489. auto wc0 = w * c0;
  1490. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  1491. size_t n_head_addr = n_idx * c1hwc0;
  1492. for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
  1493. size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
  1494. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  1495. size_t h_head_addr = c1_head_addr + h_idx * wc0;
  1496. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  1497. size_t w_head_addr = h_head_addr + w_idx * c0;
  1498. for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
  1499. size_t dst_idx = c0_idx + w_head_addr;
  1500. size_t c_idx = c0_idx + c1_idx * c0;
  1501. size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
  1502. auto pad_zero = c_idx >= c;
  1503. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  1504. }
  1505. }
  1506. }
  1507. }
  1508. }
  1509. return true;
  1510. }
  1511. bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  1512. MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  1513. MS_EXCEPTION_IF_NULL(result);
  1514. if (args.host_shape.size() != kNchwDims) {
  1515. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  1516. return false;
  1517. }
  1518. auto size = abstract::TypeIdSize(args.src_data_type);
  1519. if (size < 1) {
  1520. MS_LOG(ERROR) << "Illegal dtype.";
  1521. return false;
  1522. }
  1523. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1524. if (total_size != args.device_size) {
  1525. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1526. return false;
  1527. }
  1528. auto n = args.host_shape[kN];
  1529. auto c = args.host_shape[kC];
  1530. auto h = args.host_shape[kH];
  1531. auto w = args.host_shape[kW];
  1532. auto c1 = args.device_shape[1];
  1533. auto c0 = args.device_shape[4];
  1534. auto hw = h * w;
  1535. auto chw = c * hw;
  1536. auto wc0 = w * c0;
  1537. auto hwc0 = h * wc0;
  1538. auto c1hwc0 = c1 * hwc0;
  1539. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  1540. size_t n_head_addr = n_idx * chw;
  1541. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  1542. size_t c_head_addr = n_head_addr + c_idx * hw;
  1543. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  1544. size_t h_head_addr = c_head_addr + h_idx * w;
  1545. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  1546. size_t dst_idx = h_head_addr + w_idx;
  1547. size_t c1_idx = c_idx / c0;
  1548. size_t c0_idx = c_idx % c0;
  1549. size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
  1550. SetData(size, false, src_idx, dst_idx, args, result);
  1551. }
  1552. }
  1553. }
  1554. }
  1555. return true;
  1556. }
// Converts NCHW host data to the C1HWNCoC0 device layout (device_shape is
// {C1, H, W, N, Co, C0}). Only diagonal positions (c0_i == co_i) within the
// real channel range carry data; everything else is zero-padded.
bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
  // trans nchw to c1hwncoc0
  MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;        // element size in bytes, filled by CheckArgs
  size_t total_size = 0;  // device buffer size, filled by CheckArgs
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int co_idx = 4;  // Co axis position in device_shape
  const int c0_idx = 5;  // C0 axis position in device_shape
  auto c1 = args.device_shape[0];
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];
  for (size_t c1_i = 0; c1_i < c1; c1_i++) {
    for (size_t h_i = 0; h_i < h; h_i++) {
      for (size_t w_i = 0; w_i < w; w_i++) {
        for (size_t n_i = 0; n_i < n; n_i++) {
          for (size_t co_i = 0; co_i < co; co_i++) {
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
                               co_i * c0 + c0_i;
              size_t c_i = c0_i + c1_i * c0;  // reconstructed source channel
              size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
              // Data only lives on the Co/C0 diagonal inside the real C range.
              auto pad_zero = !(c_i < c && c0_i == co_i);
              SetData(size, pad_zero, src_idx, dst_idx, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
// Converts C1HWNCoC0 device data back to NCHW by reading each element from its
// diagonal (co_i == c0_i) slot in the {C1, H, W, N, Co, C0} source layout.
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
  // trans c1hwncoc0 to nchw
  MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;        // element size in bytes, filled by CheckArgs
  size_t total_size = 0;  // device buffer size, filled by CheckArgs
  if (!CheckArgs(args, &size, &total_size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int co_idx = 4;  // Co axis position in device_shape
  const int c0_idx = 5;  // C0 axis position in device_shape
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];
  for (size_t n_i = 0; n_i < n; n_i++) {
    for (size_t c_i = 0; c_i < c; c_i++) {
      for (size_t h_i = 0; h_i < h; h_i++) {
        for (size_t w_i = 0; w_i < w; w_i++) {
          size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
          // Split the channel into its c1 group and c0 remainder; the stored
          // value sits on the diagonal, so co_i equals c0_i.
          size_t c1_i = c_i / kCubeSize;
          size_t c0_i = c_i % kCubeSize;
          size_t co_i = c0_i;
          size_t src_idx =
            c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Converts NDC1HWC0 device data (5-D conv layout) back to NCDHW by gathering
// each destination element from its (c1, c0) slot. Returns false on invalid
// shape, dtype, or size mismatch.
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c1 = args.device_shape[C1_ndc1hwc0];
  auto c0 = args.device_shape[C0_ndc1hwc0];
  // Strides of the NCDHW destination and NDC1HWC0 source layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * cdhw;
    for (size_t c_i = 0; c_i < c; c_i++) {
      size_t c_head = n_head + c_i * dhw;
      for (size_t d_i = 0; d_i < d; d_i++) {
        size_t d_head = c_head + d_i * hw;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = d_head + h_i * w;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t dst_i = h_head + w_i;
            // Split the channel index into its c1 group and c0 remainder.
            size_t c1_i = c_i / c0;
            size_t c0_i = c_i % c0;
            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
            SetData(size, false, src_idx, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Converts NCDHW host data to the NDC1HWC0 device layout: channels are split
// into C1 groups of kCubeSize; padding channels (c_i >= c) are zero-filled.
// Returns false on invalid shape, dtype, or size mismatch.
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c0 = kCubeSize;       // channels per group
  auto c1 = DivCeil(c, c0);  // group count, rounded up
  // Strides of the NCDHW source and NDC1HWC0 destination layouts.
  const size_t cdhw = c * d * h * w;
  const size_t dhw = d * h * w;
  const size_t hw = h * w;
  const size_t dc1hwc0 = d * c1 * h * w * c0;
  const size_t c1hwc0 = c1 * h * w * c0;
  const size_t hwc0 = h * w * c0;
  const size_t wc0 = w * c0;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * dc1hwc0;
    for (size_t d_i = 0; d_i < d; d_i++) {
      size_t d_head = n_head + d_i * c1hwc0;
      for (size_t c1_i = 0; c1_i < c1; c1_i++) {
        size_t c1_head = d_head + c1_i * hwc0;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = c1_head + h_i * wc0;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t w_head = h_head + w_i * c0;
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              size_t dst_i = c0_i + w_head;
              size_t c_i = c0_i + c1_i * c0;  // reconstructed source channel
              size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              // Channels beyond the real C extent are padding and must be zero.
              auto pad_zero = c_i >= c;
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
// Converts NCDHW host data to the FRACTAL_Z_3D layout (DC1HWN1N0C0 ordering):
// N is rounded up to a multiple of kCubeSize and C to a multiple of c0 = 16;
// the padded positions are zero-filled.
bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;  // N rounded up to a cube multiple
  const size_t c0 = 16;                           // channels per fractal column
  auto c1 = DivCeil(c, c0);
  // Strides of the NCDHW source and DC1HWN1N0C0 destination layouts.
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;
  for (size_t d_i = 0; d_i < d; d_i++) {
    for (size_t c1_i = 0; c1_i < c1; c1_i++) {
      for (size_t h_i = 0; h_i < h; h_i++) {
        for (size_t w_i = 0; w_i < w; w_i++) {
          for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
              auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
              // ncdhw
              size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
              // Zero-fill positions past the real C or N extents.
              auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
// Converts FRACTAL_Z_3D device data (DC1HWN1N0C0 ordering) back to NCDHW.
// c0 is read from device_shape[3]; each destination element is gathered from
// its (c1, c0) slot.
bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = abstract::TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  const int kFZ3D_C0 = 3;  // C0 axis position in device_shape
  auto c0 = args.device_shape[kFZ3D_C0];
  // NOTE(review): c1 is computed with kCubeSize while the c1_i/c0_i split below
  // uses c0 from device_shape — consistent only when c0 == kCubeSize; confirm.
  auto c1 = DivCeil(c, kCubeSize);
  auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;  // N rounded up to a cube multiple
  // Strides of the DC1HWN1N0C0 source layout.
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;
  for (size_t n_i = 0; n_i < n; n_i++) {
    size_t n_head = n_i * cdhw;
    for (size_t c_i = 0; c_i < c; c_i++) {
      size_t c_head = n_head + c_i * dhw;
      for (size_t d_i = 0; d_i < d; d_i++) {
        size_t d_head = c_head + d_i * hw;
        for (size_t h_i = 0; h_i < h; h_i++) {
          size_t h_head = d_head + h_i * w;
          for (size_t w_i = 0; w_i < w; w_i++) {
            size_t dst_i = h_head + w_i;
            // Split the channel index into its c1 group and c0 remainder.
            size_t c1_i = c_i / c0;
            size_t c0_i = c_i % c0;
            size_t nc_i = n_i;
            size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
            SetData(size, false, src_i, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
  1844. bool CheckDimsAndDtypes(const FormatArgs &args, size_t *size) {
  1845. if (args.host_shape.size() != kNchwDims) {
  1846. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  1847. return false;
  1848. }
  1849. *size = abstract::TypeIdSize(args.src_data_type);
  1850. if (*size < 1) {
  1851. MS_LOG(ERROR) << "Illegal dtype";
  1852. return false;
  1853. }
  1854. return true;
  1855. }
// Shared implementation for grouped FRAC_Z transforms in both directions.
// Groups are packed together so that e_mult groups share one optimized cube
// block; to_device selects host->device (true) or device->host (false).
// The destination buffer is zeroed first, then only the real elements are
// copied, so padding positions stay zero.
bool NchwFracZTransWithGroups(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
  MS_EXCEPTION_IF_NULL(result);
  size_t size = 0;  // element size in bytes
  if (!(CheckDimsAndDtypes(args, &size))) {
    MS_LOG(ERROR) << "Illegal input args";
    return false;
  }
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto n_dim = args.host_shape[kN];
  auto c_dim = args.host_shape[kC];
  auto h_dim = args.host_shape[kH];
  auto w_dim = args.host_shape[kW];
  const size_t d_dim = 1;  // 4-D weights have a single depth plane
  size_t group_size = LongToSize(groups);
  auto cin_ori = c_dim;                  // input channels per group
  auto cout_ori = n_dim / group_size;    // output channels per group
  if (cin_ori == 0 || cout_ori == 0) {
    MS_LOG(ERROR) << "cin_ori, cout_ori must not be equal to 0";
    return false;
  }
  // e_mult: how many groups can be fused into one cube-aligned block without
  // wasting space (bounded by the group count itself).
  size_t e_mult = std::min(Lcm(Lcm(cin_ori, kCubeSize) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), group_size);
  if (e_mult == 0) {
    MS_LOG(EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
  }
  // Fused channel counts rounded up to cube multiples.
  size_t cin_opt = DivCeil(e_mult * cin_ori, kCubeSize) * kCubeSize;
  size_t cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize;
  size_t c1_dim = cin_opt / kCubeSize;
  size_t dst_size = to_device ? GetShapeSize(args.device_shape) * size : GetShapeSize(args.host_shape) * size;
  if (dst_size == 0) {
    return true;
  }
  // Pre-zero the destination so padding slots need no explicit writes.
  auto ret = memset_s(result, dst_size, 0, dst_size);
  if (ret != EOK) {
    MS_LOG(ERROR) << "memset failed";
    return false;
  }
  for (size_t g = 0; g < group_size; ++g) {
    for (size_t d = 0; d < d_dim; ++d) {
      for (size_t c = 0; c < c_dim; ++c) {
        for (size_t h = 0; h < h_dim; ++h) {
          for (size_t w = 0; w < w_dim; ++w) {
            for (size_t n = 0; n < cout_ori; ++n) {
              size_t e_val = g % e_mult;  // position of this group inside its fused block
              size_t dst_ci = e_val * cin_ori + c;
              size_t dst_co = e_val * cout_ori + n;
              size_t src_co = g * cout_ori + n;
              size_t temporary = dst_ci % kCubeSize;
              // Device-side offset in the grouped fractal layout.
              size_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               d * c1_dim * h_dim * w_dim * cout_opt * kCubeSize +
                               (dst_ci / kCubeSize) * h_dim * w_dim * cout_opt * kCubeSize +
                               h * w_dim * cout_opt * kCubeSize + w * cout_opt * kCubeSize + dst_co * kCubeSize +
                               temporary;
              // Host-side offset in plain NCDHW (d_dim == 1 collapses to NCHW).
              size_t hst_idx =
                src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
              if (to_device) {
                SetData(size, false, hst_idx, dev_idx, args, result);
              } else {
                SetData(size, false, dev_idx, hst_idx, args, result);
              }
            }
          }
        }
      }
    }
  }
  return true;
}
  1925. bool NchwToFracZWithGroups(const FormatArgs &args, void *result, int64_t groups) {
  1926. MS_LOG(DEBUG) << "Trans format from nchw to frac_z with groups=" << groups;
  1927. return NchwFracZTransWithGroups(args, result, true, groups);
  1928. }
  1929. bool FracZToNchwWithGroups(const FormatArgs &args, void *result, int64_t groups) {
  1930. MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
  1931. return NchwFracZTransWithGroups(args, result, false, groups);
  1932. }
  1933. } // namespace trans
  1934. } // namespace mindspore