format.cpp

#include "megbrain/imperative/transformations/format.h"
#include "megbrain/imperative/transformations/grad.h"

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"

namespace mgb {
namespace imperative {

using FT = Format::Type;

TypedValueRef<FormattedTensorValue> FormatTransformation::as(
        const FormattedTensorValue& tensor, const FT& target) const {
    return m_value_type.make(tensor.value(), target);
}

TypedValueRef<FormattedTensorValue> FormatTransformation::to(
        const FormattedTensorValue& tensor, const FT& target,
        const std::string& scope) const {
    std::vector<int32_t> pattern;
    Format format = tensor.format();
    if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
        // FIXME(czh): temporary fast path for group conv 5D weight.
        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
            pattern = {0, 1, 4, 2, 3};
        } else {
            pattern = {0, 3, 1, 2};
        }
    } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
            pattern = {0, 1, 3, 4, 2};
        } else {
            pattern = {0, 2, 3, 1};
        }
    } else {
        mgb_throw(
                MegBrainError, "Unsupported format conversion from %s to %s",
                format.to_string().c_str(), Format(target).to_string().c_str());
    }
    auto output =
            imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
    return m_value_type.make(output, target);
}
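
// Note on the patterns above: applying Dimshuffle pattern {0, 3, 1, 2} to a
// physically NHWC tensor (N, H, W, C) yields (N, C, H, W), and {0, 2, 3, 1} is
// the inverse; the 5-D variants keep the first two axes of a group-conv weight
// fixed and permute only the trailing axes.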

inline ValueRef FormatTransformation::unwrap_input(const ValueRef& input) const {
    if (auto format_input = input.as_ref(m_value_type)) {
        return format_input->value();
    } else {
        return input;
    }
}

inline ValueRefList FormatTransformation::unwrap_inputs(
        const Span<ValueRef>& inputs) const {
    ValueRefList unwrapped_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        unwrapped_inputs[i] = unwrap_input(inputs[i]);
    }
    return unwrapped_inputs;
}

inline ValueRef FormatTransformation::wrap_output(
        const ValueRef& output, Format format) const {
    return m_value_type.make(output, format);
}

inline ValueRefList FormatTransformation::wrap_outputs(
        const ValueRefList& outputs, Format format) const {
    ValueRefList wrapped_outputs(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        wrapped_outputs[i] = wrap_output(outputs[i], format);
    }
    return wrapped_outputs;
}

namespace {

ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
    auto out = ValueShape(shape);
    if (shape.ndim == 4) {
        out[1] = shape[3];
        out[2] = shape[1];
        out[3] = shape[2];
        return out;
    } else if (shape.ndim == 5) {
        out[2] = shape[4];
        out[3] = shape[2];
        out[4] = shape[3];
        return out;
    } else {
        mgb_throw(
                MegBrainError, "Unsupported shape ndim %lu in GetAttr(Shape).",
                shape.ndim);
    }
}
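
// Example: a 4-D NHWC shape (8, 32, 32, 64) is reported as (8, 64, 32, 32), so
// GetAttr(Shape) on an auto-converted NHWC tensor returns the logical NCHW
// shape; the 5-D branch does the same for group-conv weights, permuting only
// the last three axes.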

std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
    auto out = std::vector<int32_t>(shape);
    if (shape.size() == 4) {
        out[1] = shape[2];
        out[2] = shape[3];
        out[3] = shape[1];
        return out;
    } else if (shape.size() == 5) {
        // GIOHW -> GIHWO
        out[2] = shape[3];
        out[3] = shape[4];
        out[4] = shape[2];
        return out;
    } else {
        mgb_throw(
                MegBrainError,
                "Unsupported shape ndim %lu when converting an NCHW shape to NHWC.",
                shape.size());
    }
}

using FormatRule = std::function<ValueRefList(
        const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;

template <typename T>
void register_format_rule(ValueRefList (*rule)(
        const T&, Span<ValueRef>&, const bool&, const FormatTransformation&)) {
    format_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef>& inputs,
                                         const bool& auto_convert,
                                         const FormatTransformation& t) {
        return (*rule)(def.cast_final_safe<T>(), inputs, auto_convert, t);
    };
}
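
// format_rules maps an OpDef's Typeinfo* to a type-erased rule; the lambda
// registered here downcasts the OpDef back to its concrete type before calling
// the rule. apply_transformation() below looks ops up in this table and falls
// back to plain format unification when no rule is registered.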

inline auto convert_nchw2nhwc_pattern(const std::vector<int32_t>& pattern) {
    mgb_assert(pattern.size() == 4);
    auto nhwc_pattern = pattern;
    for (size_t idx = 0; idx < 4; ++idx) {
        auto dim = pattern[idx];
        if (dim == 1) {
            nhwc_pattern[idx] = 3;
        } else if (dim == 2) {
            nhwc_pattern[idx] = 1;
        } else if (dim == 3) {
            nhwc_pattern[idx] = 2;
        }
    }
    return nhwc_pattern;
}
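
// Example: the logical NCHW transpose pattern {0, 2, 3, 1} is remapped to
// {0, 1, 2, 3} (an identity shuffle on the physical tensor), because axes
// C/H/W at NCHW positions 1/2/3 live at NHWC positions 3/1/2.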

ValueRefList dimshuffle_rule(
        const Dimshuffle& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() == 1);
    auto& src = inputs[0].cast(t.value_type());
    // Only patterns expressed for NCHW are converted to NHWC currently.
    if (auto_convert && src.format() == FT::NHWC) {
        auto pattern = convert_nchw2nhwc_pattern(op.pattern);
        // dimshuffle will not maintain the NHWC format
        return t.wrap_outputs(imperative::apply(
                *Dimshuffle::make(std::move(pattern), op.scope()),
                t.unwrap_inputs(inputs)));
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
    mgb_assert(shape.layout().total_nr_elems() == 4);
    auto* old_ptr = shape.ptr<dt_int32>();
    auto cn = shape.comp_node();
    auto layout = shape.layout();
    auto nhwc_shape = HostTensorND(cn, layout);
    auto* new_ptr = nhwc_shape.ptr<dt_int32>();
    new_ptr[0] = old_ptr[0];
    new_ptr[1] = old_ptr[2];
    new_ptr[2] = old_ptr[3];
    new_ptr[3] = old_ptr[1];
    auto hv = HostStorage::make(nhwc_shape.storage());
    auto nhwc_shape_input =
            imperative::apply(CreateTensor(CreateTensor::Const, cn, layout), hv)[0];
    return nhwc_shape_input;
}
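
// The helper above rebuilds a host int32 shape tensor with its NCHW entries
// reordered to NHWC and wraps it as a const tensor input; reshape_rule and
// broadcast_rule use it when the target shape arrives as a tensor operand
// rather than as an op attribute.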

ValueRefList reshape_rule(
        const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still in NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Reshape::make(op.axis, nhwc_shape),
                        {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still in NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

ValueRefList broadcast_rule(
        const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still in NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still in NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

inline bool is_reduce_ndim_idx_items(
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
        const Span<ValueRef>& inputs) {
    for (auto i = 0; i < items.size(); ++i) {
        auto&& [axis, begin, end, step, idx] = items[i];
        if (idx) {
            // if inputs[i] contains more than one value, ndim will not be reduced.
            return inputs[i].is_scalar();
        }
    }
    return false;
}

inline auto convert_nchw2nhwc_idx_items(
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
    auto nhwc_items = items;
    for (auto i = 0; i < nhwc_items.size(); ++i) {
        auto&& [axis, begin, end, step, idx] = nhwc_items[i];
        if (axis == 2 || axis == 3) {
            nhwc_items[i] = {axis - 1, begin, end, step, idx};
        } else if (axis == 1) {
            nhwc_items[i] = {3, begin, end, step, idx};
        }
    }
    return nhwc_items;
}
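
// Example: an index item on axis 1 (C in logical NCHW) is moved to axis 3 (C in
// the physical NHWC tensor), while items on axes 2/3 (H/W) shift down to 1/2;
// axis 0 and the begin/end/step/idx flags are left untouched.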

template <typename T>
ValueRefList subtensor_rule(
        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    bool is_reduce_ndim = is_reduce_ndim_idx_items(
            op.items, {&inputs[1], &inputs[inputs.size() - 1]});
    if (!is_reduce_ndim) {
        // only NHWC2NCHW conversion is supported; otherwise maintain src's format
        if (!(auto_convert && src.format() == FT::NHWC)) {
            return {t.wrap_output(
                    imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
        }
        auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
        auto outputs = imperative::apply(
                *T::make(std::move(nhwc_items), op.scope()), t.unwrap_inputs(inputs));
        return t.wrap_outputs(outputs, FT::NHWC);
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

template <typename T>
ValueRefList setsubtensor_rule(
        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() >= 2);
    auto& src = inputs[0].cast(t.value_type());
    bool is_reduce_ndim = is_reduce_ndim_idx_items(
            op.items, {&inputs[2], &inputs[inputs.size() - 1]});
    if (!is_reduce_ndim) {
        // only NHWC2NCHW conversion is supported; otherwise maintain src's format
        if (!(auto_convert && src.format() == FT::NHWC)) {
            return {t.wrap_output(
                    imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
        }
        // value has been broadcast to src's fake NCHW shape.
        auto& value = inputs[1].cast(t.value_type());
        auto& format = value.format();
        auto nhwc_inputs = ValueRefList(inputs.size());
        if (format == FT::DEFAULT || format == FT::NCHW) {
            // the value for setsubtensor should be transposed to match the shape.
            auto nhwc_value = t.to(value, FT::NHWC);
            // make new inputs for setsubtensor
            nhwc_inputs[0] = src.value();
            nhwc_inputs[1] = nhwc_value->value();
            for (auto i = 2; i < inputs.size(); ++i) {
                nhwc_inputs[i] = t.unwrap_input(inputs[i]);
            }
        } else if (format != FT::NHWC) {
            mgb_throw(
                    MegBrainError, "Unsupported format(%s) of value for setsubtensor.",
                    format.to_string().c_str());
        }
        auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
        auto outputs = imperative::apply(
                *T::make(std::move(nhwc_items), op.scope()), nhwc_inputs);
        return t.wrap_outputs(outputs, FT::NHWC);
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}

inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
    FT format(FT::DEFAULT);
    for (auto& inp : inputs) {
        auto&& inp_format = inp.cast(t.value_type()).format();
        if (inp_format != FT::DEFAULT) {
            mgb_assert(format == FT::DEFAULT || inp_format == format);
            format = inp_format.type();
        }
    }
    return format;
}

inline ValueRefList unify_inputs_format(
        const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope,
        const FormatTransformation& t) {
    ValueRefList unified_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& inp = inputs[i].cast(t.value_type());
        if (inp.format() != dst_fmt &&
            inp.value().shape().cast<ShapeValue>().ndim == 4) {
            unified_inputs[i] = t.to(inp, dst_fmt, scope);
        } else {
            unified_inputs[i] = inputs[i];
        }
    }
    return unified_inputs;
}
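
// Note: unify_inputs_format only transposes 4-D operands; inputs whose ndim is
// not 4 (scalars, vectors, 5-D group-conv weights) are passed through with
// their original layout and format tag.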

ValueRefList elemwise_rule(
        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    FT format = get_inputs_format(inputs, t);
    if (format == FT::NHWC && auto_convert) {
        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
        return t.wrap_outputs(
                imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
}

ValueRefList concat_rule(
        const Concat& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    FT format = get_inputs_format(inputs, t);
    if (!(format == FT::NHWC && auto_convert)) {
        return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
    }
    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
    // TODO: handle 5D NHWC Tensor from group conv
    auto axis = op.axis;
    if (axis == 2 || axis == 3) {
        axis = axis - 1;
    } else if (axis == 1) {
        axis = 3;
    }
    return t.wrap_outputs(
            imperative::apply(
                    *Concat::make(axis, op.comp_node, op.scope()),
                    t.unwrap_inputs(unified_inputs)),
            format);
}

ValueRefList identity_rule_helper(
        const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
    // mgb_assert(inputs.size() == 1);
    if (auto& src = inputs[0].as_ref(t.value_type())) {
        return t.wrap_outputs(
                imperative::apply(op, t.unwrap_inputs(inputs)), src->format());
    } else {
        return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
    }
}

ValueRefList batchnorm_rule(
        const BatchNorm& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    auto&& inp_format = inputs[0].cast(t.value_type()).format();
    if (inp_format == FT::NHWC) {
        auto&& new_param = op.param();
        new_param.param_dim = BatchNorm::ParamDim::DIM_111C;
        auto new_op = BatchNorm::make(new_param);
        return identity_rule_helper(*new_op, inputs, t);
    }
    return identity_rule_helper(op, inputs, t);
}

ValueRefList adaptive_pooling_rule(
        const AdaptivePooling& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    auto&& inp_format = inputs[0].cast(t.value_type()).format();
    if (inp_format == FT::NHWC) {
        auto&& new_param = op.param();
        new_param.format = AdaptivePooling::Format::NHWC;
        auto new_op = AdaptivePooling::make(new_param, op.shape);
        return identity_rule_helper(*new_op, inputs, t);
    }
    return identity_rule_helper(op, inputs, t);
}

// clang-format off
#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
    cb(CompiledOp)                         \
    cb(SubgraphOp)

#define FOREACH_IDENTITY_OP(cb) \
    cb(Copy)                    \
    cb(FastpathCopy)            \
    cb(TypeCvt)                 \
    cb(Dropout)                 \
    cb(Identity)

#define FOREACH_FORMAT_OP(cb) \
    cb(WarpAffine)            \
    cb(Resize)

#define FOREACH_FORMAT_POLICY_OP(cb) \
    cb(Pooling)                      \
    cb(Convolution)

#define FOREACH_BYPASS_OP(cb) \
    cb(ParamPackSplit)        \
    cb(ParamPackConcat)       \
    cb(CollectiveComm)        \
    cb(CheckNonFinite)
// clang-format on

// multi-input ops without params
#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op)                               \
    ValueRefList Op##_rule(                                                  \
            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
            const FormatTransformation& t) {                                 \
        FT format = get_inputs_format(inputs, t);                            \
        return t.wrap_outputs(                                               \
                imperative::apply(_op, t.unwrap_inputs(inputs)), format);    \
    }
FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE)
#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE

// identity ops
#define CREATE_IDENTITY_OP_RULE(Op)                                          \
    ValueRefList Op##_rule(                                                  \
            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
            const FormatTransformation& t) {                                 \
        return identity_rule_helper(_op, inputs, t);                         \
    }
FOREACH_IDENTITY_OP(CREATE_IDENTITY_OP_RULE)
#undef CREATE_IDENTITY_OP_RULE

// identity ops with a Format param
#define CREATE_FORMAT_OP_RULE(Op)                                            \
    ValueRefList Op##_rule(                                                  \
            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
            const FormatTransformation& t) {                                 \
        auto&& inp_format = inputs[0].cast(t.value_type()).format();         \
        if (inp_format == FT::NHWC) {                                        \
            auto&& new_param = _op.param();                                  \
            new_param.format = Op::Format::NHWC;                             \
            auto new_op = Op::make(new_param);                               \
            return identity_rule_helper(*new_op, inputs, t);                 \
        }                                                                    \
        return identity_rule_helper(_op, inputs, t);                         \
    }
FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
#undef CREATE_FORMAT_OP_RULE

// identity ops with Format and policy params
#define CREATE_FORMAT_POLICY_OP_RULE(Op)                                     \
    ValueRefList Op##_rule(                                                  \
            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
            const FormatTransformation& t) {                                 \
        auto&& inp_format = inputs[0].cast(t.value_type()).format();         \
        if (inp_format == FT::NHWC) {                                        \
            auto&& new_param = _op.param();                                  \
            new_param.format = Op::Format::NHWC;                             \
            auto new_op = Op::make(new_param, _op.policy());                 \
            return identity_rule_helper(*new_op, inputs, t);                 \
        }                                                                    \
        return identity_rule_helper(_op, inputs, t);                         \
    }
FOREACH_FORMAT_POLICY_OP(CREATE_FORMAT_POLICY_OP_RULE)

// bypass ops
#define CREATE_BYPASS_OP_RULE(Op)                                            \
    ValueRefList Op##_rule(                                                  \
            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
            const FormatTransformation& t) {                                 \
        return t.wrap_outputs(imperative::apply(_op, t.unwrap_inputs(inputs))); \
    }
FOREACH_BYPASS_OP(CREATE_BYPASS_OP_RULE)
#undef CREATE_BYPASS_OP_RULE
#undef CREATE_FORMAT_POLICY_OP_RULE

#define REGISTER_OP_RULE(op) register_format_rule(op##_rule);
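// The static FormatRuleRegistry instance `_` below runs during static
// initialization of this translation unit and fills format_rules with every
// rule defined above, so the table is ready before the first op reaches
// apply_transformation().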
struct FormatRuleRegistry {
    FormatRuleRegistry() {
        register_format_rule(dimshuffle_rule);
        register_format_rule(reshape_rule);
        register_format_rule(broadcast_rule);
        register_format_rule(subtensor_rule<Subtensor>);
        register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
        register_format_rule(setsubtensor_rule<SetSubtensor>);
        register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
        register_format_rule(elemwise_rule);
        register_format_rule(concat_rule);
        register_format_rule(batchnorm_rule);
        register_format_rule(adaptive_pooling_rule);
        FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
        FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
        FOREACH_FORMAT_OP(REGISTER_OP_RULE)
        FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
        FOREACH_BYPASS_OP(REGISTER_OP_RULE)
    }
} _;
#undef REGISTER_OP_RULE

}  // namespace

ValueRefList FormatTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* apply_op = op.as<ApplyOp>()) {
        // all inputs should be FormattedTensorValue
        auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != format_rules.end()) {
            return iter->second(apply_op->op(), inputs, m_auto_convert, *this);
        } else {
            auto unified_inputs = unify_inputs_format(
                    inputs, FT::DEFAULT, apply_op->op().scope(), *this);
            return wrap_outputs(imperative::apply(op, unwrap_inputs(unified_inputs)));
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        auto format = create_tensor->format();
        if (format == FT::NHWC) {
            auto output = wrap_output(imperative::apply(op, inputs)[0]);
            output = to(output.cast(m_value_type), FT::NHWC, "");
            return {output};
        } else {
            return {wrap_output(imperative::apply(op, inputs)[0], format)};
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        auto&& input = inputs.item();
        if (!input.is(m_value_type)) {
            return imperative::apply(op, input);
        }
        auto& src = input.cast(m_value_type);
        if (!(m_auto_convert && src.format() == FT::NHWC)) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        switch (get_attr->attr()) {
            case GetAttr::Shape: {
                auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
                auto shape = convert_nhwc2nchw_shape(output.cast<ShapeValue>());
                return {ShapeValue::make(shape)};
            }
            case GetAttr::Value: {
                auto nchw_src = unwrap_input(to(src, FT::DEFAULT, ""));
                return imperative::apply(op, {nchw_src});
            }
            default:
                return imperative::apply(op, unwrap_inputs(inputs));
        }
    } else if (op.is<GetFormat>()) {
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        if (inp_ref) {
            return {FormatValue::make(inp_ref->format())};
        } else {
            mgb_log_warn(
                    "Not FormattedTensorValue input for GetFormat op: %s, %s",
                    op.to_string().c_str(), inputs[0].to_string().c_str());
            return {FormatValue::make(FT::DEFAULT)};
        }
    } else if (auto* _op = op.as<SetFormat>()) {
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
        return {m_value_type.make(inp_ref->value(), _op->format())};
    } else if (op.is<Operator::IdentityLike>()) {
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        if (inp_ref) {
            auto&& format = inp_ref->format();
            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
        } else {
            mgb_log_warn(
                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
                    op.to_string().c_str(), inputs[0].to_string().c_str());
            return imperative::apply(op, inputs);
        }
    } else if (op.is<AttachGrad>()) {
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        if (inp_ref) {
            auto format = inp_ref->format();
            GenericFunction callback =
                    (GenericFunction&)inputs[1].cast<FunctionValue>();
            // wrap param grads as FormattedTensors
            GenericFunction new_callback =
                    [&, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                auto wrapped_inputs = SmallVector<ValueRef>{
                        m_value_type.make(inputs_.item(), format)};
                auto ret = callback(wrapped_inputs);
                return ret;
            };
            auto&& outputs = imperative::apply(
                    op, inp_ref->value(), FunctionValue::make(new_callback));
            // wrap params (GradValue) as FormattedTensors
            return wrap_outputs(outputs, format);
        } else {
            mgb_log_warn(
                    "Not FormattedTensorValue input for AttachGrad op: %s, %s",
                    op.to_string().c_str(), inputs[0].to_string().c_str());
            return imperative::apply(op, inputs);
        }
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // wrap grads produced in Function backward as FormattedTensors
        size_t nr_inputs = set_grad->nr_inputs();
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
        // run the original apply.
        // grads need no unwrap/wrap here; they are unwrapped in GradTransformation.
        auto&& outputs = imperative::apply(op, unwrap_inputs(inputs));
        // handle the outputs' formats
        auto wrapped_outputs = ValueRefList(nr_outputs);
        for (size_t i = 0; i < nr_outputs; ++i) {
            if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
                wrapped_outputs[i] =
                        m_value_type.make(outputs[i], output_ref->format());
            } else {
                mgb_log_warn(
                        "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
                        op.to_string().c_str(), inputs_[i].to_string().c_str());
                wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT);
            }
        }
        return wrapped_outputs;
    } else {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
}

}  // namespace imperative
}  // namespace mgb