
tensor_manip.cpp 12 kB

/**
 * \file imperative/src/impl/ops/tensor_manip.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"

#include "../async_releaser.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb::imperative {

namespace get_var_shape {
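// GetVarShape returns the shape of its input(s) as an Int32 tensor: with a
// single input the shape is taken directly, with multiple inputs the
// broadcast-deduced common shape is used; a valid `axis` selects one dimension.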
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig config{op_def.make_name()};
    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}
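// Dispatch on DEFAULT_CPU only when every input's value is already available
// on the host; otherwise fall back to the device KERNEL path.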
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    for (auto&& inp : inputs) {
        // FIXME(czh): remove value check after proxy graph's
        // apply_on_device_tensornd is supported and output Tensor
        // is made before add_task.
        // then if layout is valid, ptr->layout must be ready
        if (inp.value.empty() || inp.value.layout().ndim == 0) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}
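// Computes the shape tensor on the default CPU comp node; for multiple inputs
// the common shape is deduced via megdnn::Elemwise::deduce_shape.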
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = inputs[0].layout();
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout();
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert(
            (*outputs)[0].comp_node() == CompNode::default_cpu(),
            "GetVarShape's apply_on_device_tensornd should receive default_cpu "
            "outputs.");
    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}
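// Runs apply_on_device_tensornd on default_cpu and proxies the result back to
// the comp node of the first input.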
HostTensorND get_var_shape_host_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> input_tensornds;
    for (auto&& inp : inputs) {
        input_tensornds.push_back(inp->dev_tensor());
    }
    SmallVector<DeviceTensorND> output_tensornds = {
            {CompNode::default_cpu(), dtype::Int32()}};
    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
    // restore to input comp_node
    return HostTensorND::make_proxy(output_tensornds[0])
            .proxy_to_comp_node(inputs[0]->comp_node());
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
}
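// When the input shapes are known, the output value is computed eagerly so the
// shape can be constant-folded; otherwise only dtype and comp node are
// reported and the result is marked as not validated.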
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    auto&& desc = inputs[0];
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = desc.layout;
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout;
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    if (!shp.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp[i];
        }
    } else {
        int32_t axis = op_def.axis;
        if (axis < 0) {
            axis += shp.ndim;
        }
        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = shp[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(node->param());
}

OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
} // namespace get_var_shape
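// ParamPackSplit / ParamPackConcat: split a flat 1-D parameter buffer into
// per-parameter tensors at the given offsets, and concatenate many tensors
// back into one flat buffer.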
namespace param_pack {

TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray ret;
    for (auto&& i : shapes) {
        SmallVector<size_t> shape(i.begin(), i.end());
        TensorShape shp(shape);
        ret.push_back(shp);
    }
    return ret;
}
cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    auto&& graph = inputs[0]->owner_graph();
    auto&& shapes = get_shapes(param.shapes);
    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
    return opr;
}
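// The split is a metadata-only operation: each output is a sub-tensor view
// into the packed input buffer, so no data is copied.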
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    mgb_assert(
            inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
    auto&& inp = inputs[0];
    auto&& shp = inp->layout();
    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
    SmallVector<TensorPtr> ret;
    auto&& shapes = get_shapes(param.shapes);
    size_t dtype_size = inputs[0]->layout().dtype.size();
    for (size_t i = 0; i < shapes.size(); ++i) {
        // memory forward
        ret.push_back(inputs[0]->sub(param.offsets[i * 2] * dtype_size, shapes[i]));
    }
    return ret;
}

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
        .apply_on_var_node(param_pack_split_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
        .fallback();
cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& param = def.cast_final_safe<ParamPackConcat>();
    auto&& graph = inputs[0]->owner_graph();
    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
    return opr;
}
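// Collects the raw device pointers of all data inputs into a host-side table,
// queries megdnn for the workspace size, runs the ParamPackConcat kernel, and
// hands the pointer table to AsyncReleaser so it outlives the asynchronous
// kernel execution. The last input holds the offsets.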
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    def.cast_final_safe<ParamPackConcat>();
    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
    auto comp_node = inputs.front()->comp_node();
    auto dtype = inputs.front()->dtype();
    size_t nr_inputs = inputs.size() - 1;
    size_t nr_elems = 0;
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto& input = inputs[i];
        mgb_assert(
                comp_node == input->comp_node(),
                "inputs for param_pack_concat must in same comp_node");
        mgb_assert(
                dtype == input->dtype(),
                "inputs for param_pack_concat must have same dtype");
        nr_elems += input->layout().total_nr_elems();
    }
    auto dest_layout = TensorLayout({nr_elems}, dtype);
    auto output = Tensor::make(dest_layout, comp_node);
    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
    size_t srcs_size = sizeof(void*) * nr_inputs;
    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
    std::shared_ptr<dt_byte> srcs_ptr = {
            (dt_byte*)srcs_raw_ptr,
            [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
    size_t ws_size;
    {
        TensorShapeArray src_shapes;
        for (size_t i = 0; i < nr_inputs; ++i) {
            src_shapes.push_back(inputs[i]->shape());
        }
        ws_size = caller.op->get_workspace_in_bytes(
                src_shapes, inputs.back()->shape(), TensorShape{});
    }
    for (size_t i = 0; i < nr_inputs; ++i) {
        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
    }
    HostTensorStorage srcs_storage;
    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
    caller.op->exec(
            {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
            output->dev_tensor().as_megdnn(),
            caller.create_workspace({{ws_size}, dtype::Byte()}));
    AsyncReleaser::inst()->add(
            HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
    return {output};
}

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
        .apply_on_var_node(param_pack_concat_apply_on_var_node)
        .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
        .fallback();
} // namespace param_pack
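// Split maps the imperative Split op onto opr::Split, either averaging into
// `nsections` parts or using explicitly specified partition sizes.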
namespace split {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    using Options = opr::Split::Options;
    auto* node = &node_->cast_final_safe<opr::Split>();
    auto&& opt = node->options();
    int axis = opt.axis;
    mgb_assert(
            opt.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(opt.partition.size() == opt.nr_part);
    return Split::make(axis, 0);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    using Options = opr::Split::Options;
    auto&& sp = static_cast<const Split&>(def);
    OperatorNodeConfig config{sp.make_name()};
    opr::Split::Options opt;
    if (sp.nsections) {
        opt = Options::make_average(sp.axis, sp.nsections);
        opt.method = Options::Method::CALL_BACK;
    } else {
        opt.axis = sp.axis;
        opt.method = Options::Method::SPECIFY;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
    }
    return opr::Split::make(inputs[0], opt, config);
}

OP_TRAIT_REG(Split, Split, opr::Split)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
} // namespace split
} // namespace mgb::imperative