backward_graph.cpp

#include "./helper.h"

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/batch_norm.h"

using namespace mgb;
using namespace cg;
using namespace imperative;

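// Pack the tensors that the backward graph actually consumes:
// bg.input_mask flags, in order, which entries of the concatenated
// (inputs ++ outputs ++ grads) sequence must be forwarded into bg.graph.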
template <typename T>
T prepare_backward_graph_inputs(
        const EncodedSubgraph& bg, const T& inputs, const T& outputs, const T& grads) {
    T ret;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.input_mask[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

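// Scatter the dense list of computed grads back to one slot per forward
// input; entries whose mask bit is unset stay default-constructed (null).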
template <typename T, typename U>
T expand_grads(const U& mask, const T& outputs) {
    T ret(mask.size());
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            ret[i] = outputs[j++];
        }
    }
    return ret;
}

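// Same packing as prepare_backward_graph_inputs, but for the two-phase
// OptimizedBackwardGraphResult: the precomputed tensors come first, then
// whatever save_for_backward selects from (inputs ++ outputs ++ grads).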
template <typename T>
T prepare_optimized_backward_inputs(
        const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs,
        const T& outputs, const T& grads) {
    T ret = precomp;
    size_t i = 0;
    for (auto&& t : inputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : outputs) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    for (auto&& t : grads) {
        if (bg.save_for_backward[i++]) {
            ret.push_back(t);
        }
    }
    return ret;
}

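// Execution callback handed to EncodedSubgraph::apply below: infers output
// descriptors from the physical inputs, then runs the op eagerly
// (nr_outputs is unused here).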
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& i : inputs) {
        input_descs.push_back({i->layout(), i->comp_node()});
    }
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*def, input_descs);
    return OpDef::apply_on_physical_tensor(*def, inputs, output_descs, validated);
}

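// Build the backward graph of an elementwise MUL, evaluate it on concrete
// tensors, and check the analytic gradients d(a*b)/da = b * dc and
// d(a*b)/db = a * dc elementwise.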
TEST(TestImperative, BackwardGraphBasic) {
    HostTensorGenerator<> gen;
    SmallVector<HostTensorND> hvs;
    SmallVector<TensorPtr> inputs;
    for (size_t i = 0; i < 2; ++i) {
        hvs.push_back(*gen({42}));
        inputs.push_back(Tensor::make(hvs.back()));
    }
    using Param = opr::Elemwise::Param;
    Param param{Param::Mode::MUL};
    auto attr = OprAttr::make("Elemwise");
    attr->cast_final_safe<OprAttr>().param.write_pod(param);
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& i : inputs) {
        input_descs.push_back({i->layout(), i->comp_node()});
    }
    auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true});
    auto&& save_for_backward = result.input_mask;
    auto&& input_has_grad = result.output_mask;
    for (size_t i = 0; i < inputs.size(); i++) {
        input_descs[i].value = inputs[i]->dev_tensor();
    }
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*attr, input_descs);
    auto outputs =
            OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
    inputs.push_back(outputs[0]);
    hvs.push_back(*gen({42}));
    inputs.push_back(Tensor::make(hvs.back()));
    mgb_assert(save_for_backward.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (!save_for_backward[i]) {
            inputs[i].reset(); // drop unused tensor
        }
    }
    SmallVector<TensorPtr> backward_graph_inputs;
    for (auto&& i : inputs) {
        if (i) {
            backward_graph_inputs.push_back(i);
        }
    }
    inputs.clear();
    auto input_grads = result.graph.apply<TensorPtr>(
            backward_graph_inputs, apply_shared_on_physical_tensor,
            [&](auto&& x) { return x; });
    mgb_assert(input_grads.size() == input_has_grad.size());
    for (size_t i = 0; i < input_has_grad.size(); ++i) {
        mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
    }
    SmallVector<HostTensorND> res;
    for (auto&& i : input_grads) {
        res.emplace_back();
        res.back().copy_from(i->dev_tensor()).sync();
    }
  129. for (size_t i = 0; i < 42; ++i) {
  130. for (size_t j = 0; j < 1; ++j) {
  131. ASSERT_EQ(
  132. hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i],
  133. res[j ^ 1].ptr<float>()[i]);
  134. }
  135. }
  136. }
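// The PROFILE execution-policy strategy set on a forward op must be inherited
// by the op inside its generated backward graph; checked for Convolution,
// Pooling and MatrixMul.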
TEST(TestImperative, ProfileBackward) {
    auto cn = CompNode::load("xpux");
    using Policy = megdnn::param::ExecutionPolicy;
    using S = Policy::Strategy;
    Policy policy;
    policy.strategy = S::PROFILE;
    {
        megdnn::param::Convolution param;
        auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy));
        LogicalTensorDesc inp_desc = {
                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
        LogicalTensorDesc weight_desc = {
                TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn};
        auto bg = OpDef::make_backward_graph(
                *op, {inp_desc, weight_desc}, {true, false}, {true});
        auto&& bop = (bg.graph.exprs.at(0)).op;
        auto&& attr = bop->cast_final_safe<OprAttr>();
        // attr.type = ConvolutionBackwardDataV2
        mgb_assert(attr.policy.strategy == S::PROFILE);
    }
    {
        megdnn::param::Pooling param;
        auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy));
        LogicalTensorDesc inp_desc = {
                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
        auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true});
        auto&& bop = (bg.graph.exprs.at(0)).op;
        auto&& attr = bop->cast_final_safe<OprAttr>();
        // attr.type = PoolingBackwardV1
        mgb_assert(attr.policy.strategy == S::PROFILE);
    }
    {
        megdnn::param::MatrixMul param;
        auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2));
        LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn};
        LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn};
        auto bg = OpDef::make_backward_graph(
                *op, {inp1_desc, inp2_desc}, {true, false}, {true});
        auto&& bop = (bg.graph.exprs.at(0)).op;
        auto&& attr = bop->cast_final_safe<OprAttr>();
        // attr.type = MatrixMulV2
        mgb_assert(attr.policy.strategy == S::PROFILE);
    }
}

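// Identity has the trivial backward graph: the output grad dc is passed
// through unchanged as the input grad.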
TEST(TestImperative, BackwardGraphIdentity) {
    HostTensorGenerator<> gen;
    auto host_a = gen({42}), host_dc = gen({42});
    auto a = Tensor::make(*host_a), dc = Tensor::make(*host_dc);
    SmallVector<TensorPtr> inputs;
    inputs.push_back(a);
    auto attr = OprAttr::make("Identity");
    attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({});
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.push_back({a->layout(), a->comp_node()});
    auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
    auto&& save_for_backward = result.input_mask;
    auto&& input_has_grad = result.output_mask;
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*attr, input_descs);
    auto outputs =
            OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
    inputs.push_back(outputs[0]);
    inputs.push_back(dc);
    mgb_assert(save_for_backward.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (!save_for_backward[i]) {
            inputs[i].reset(); // drop unused tensor
        }
    }
    SmallVector<TensorPtr> backward_graph_inputs;
    for (auto&& i : inputs) {
        if (i) {
            backward_graph_inputs.push_back(i);
        }
    }
    inputs.clear();
    auto input_grads = result.graph.apply<TensorPtr>(
            backward_graph_inputs, apply_shared_on_physical_tensor,
            [&](auto&& x) { return x; });
    mgb_assert(input_grads.size() == input_has_grad.size());
    for (size_t i = 0; i < input_has_grad.size(); ++i) {
        mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
    }
    HostTensorND hv;
    hv.copy_from(input_grads[0]->dev_tensor()).sync();
    for (size_t i = 0; i < 42; ++i) {
        ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]);
    }
}

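// Construction-only test: make_backward_graph must succeed for BatchNormV1 in
// TRAINING mode, both with the running mean/variance inputs (5 inputs,
// 6 outputs) and without them (3 inputs, 4 outputs).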
TEST(TestImperative, BatchNormGrad) {
    auto cn = CompNode::load("xpux");
    using Param = opr::BatchNorm::Param;
    size_t N = 2, C = 3, H = 5, W = 5;
    LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn};
    LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn};
    {
        auto op = OprAttr::make("BatchNormV1");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(
                attr, {inp, stat, stat, stat, stat}, {true, true, true, false, false},
                {false, false, false, false, false, true});
    }
    {
        auto op = OprAttr::make("BatchNormV1");
        auto&& attr = op->cast_final_safe<OprAttr>();
        Param param;
        param.fwd_mode = Param::FwdMode::TRAINING;
        attr.param.write_pod(param);
        OpDef::make_backward_graph(
                attr, {inp, stat, stat}, {true, true, true},
                {false, false, false, true});
    }
}

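// Split the ADD backward graph into a precompute part (runnable right after
// forward, while the forward tensors are still alive) and a deferred backward
// part, then check that two-phase evaluation matches the single-pass result.
// For ADD no forward tensor needs saving (save_for_backward[0..2] are false);
// the precomp outputs are two rank-1 tensors of at most 2 elements, presumably
// the input shapes needed to reduce the broadcast grads.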
TEST(TestImperative, OptimizedBackwardGraphBasic) {
    auto cn = CompNode::load("xpux");
    LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn};
    HostTensorGenerator<> gen;
    auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD));
    auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true});
    auto obg = OptimizedBackwardGraphResult(bg);
    ASSERT_EQ(obg.save_for_backward.size(), 4);
    ASSERT_FALSE(obg.save_for_backward[0]);
    ASSERT_FALSE(obg.save_for_backward[1]);
    ASSERT_FALSE(obg.save_for_backward[2]);
    auto a_hv = gen({42});
    auto b_hv = gen({5, 42});
    auto dc_hv = gen({5, 42});
    auto a_tn = Tensor::make(*a_hv);
    auto b_tn = Tensor::make(*b_hv);
    auto dc_tn = Tensor::make(*dc_hv);
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.push_back({a_tn->layout(), a_tn->comp_node(), a_tn->dev_tensor()});
    input_descs.push_back({b_tn->layout(), b_tn->comp_node(), b_tn->dev_tensor()});
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    auto c_tn = OpDef::apply_on_physical_tensor(
            *op, {a_tn, b_tn}, output_descs, validated)[0];
    auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
            bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads = expand_grads(
            bg.output_mask,
            bg.graph.apply<TensorPtr>(
                    backward_graph_inputs, apply_shared_on_physical_tensor,
                    [&](auto&& x) { return x; }));
    auto precomp = obg.precomp.apply<TensorPtr>(
            SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor,
            [&](auto&& x) { return x; });
    ASSERT_EQ(precomp.size(), 2);
    ASSERT_EQ(precomp[0]->shape().ndim, 1);
    ASSERT_LE(precomp[0]->shape()[0], 2);
    ASSERT_EQ(precomp[1]->shape().ndim, 1);
    ASSERT_LE(precomp[1]->shape()[0], 2);
    auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
            obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
    auto grads2 = expand_grads(
            obg.input_has_grad,
            obg.backward.apply<TensorPtr>(
                    backward_inputs, apply_shared_on_physical_tensor,
                    [&](auto&& x) { return x; }));
    ASSERT_EQ(grads2.size(), 2);
    MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());
    MGB_ASSERT_TENSOR_EQ(grads[1]->get_value(), grads2[1]->get_value());
}