/**
 * \file imperative/src/impl/transformations/scalar.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/transformations/scalar.h"

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {

namespace {

using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>);
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
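
// Build a const Int32 tensor of shape {1} (holding the value 1) on `device`.
// It stands in for the "scalar" target shape whenever an op needs an explicit
// shape tensor for a logically 0-dim result.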
ValueRef make_scalar_shape(CompNode device) {
    HostTensorND scalar_shape(device, {1}, dtype::Int32());
    scalar_shape.ptr<dt_int32>()[0] = 1;
    return imperative::apply(
            CreateTensor(CreateTensor::Const, device, scalar_shape.layout()),
            HostStorage::make(scalar_shape.storage()))[0];
}
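
// Returns true iff `shape` (a shape-valued tensor) describes a 0-dim tensor,
// i.e. the shape tensor itself has shape {0}. Unknown shapes are treated as
// non-scalar.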
bool is_scalar_shape(ValueRef shape) {
    // may have performance issue
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
        // assume not scalar
        return false;
    }
    return *shape_of_shape == ValueShape{0};
}
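
// Register `rule` for op type T, keyed by its Typeinfo. The wrapping lambda
// downcasts the generic OpDef to T before forwarding to the typed rule.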
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)>
void register_scalar_rule() {
    scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
                                     Span<bool> inputs_mask) {
        return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask);
    };
}
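
// Generic rule for elementwise-like ops: apply the op as usual and re-wrap the
// first output as a ScalarValue when every input was a scalar. `nr_inputs == 0`
// means "any arity".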
template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule(
        const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if constexpr (nr_inputs != 0) {
        mgb_assert(inputs.size() == nr_inputs, "inputs size mismatch");
    }
    bool all_scalar = true;
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op_def, inputs);
    if (all_scalar) {
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
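
// RemoveAxis yields a scalar when it removes every remaining axis of the
// input. If only a single axis is left, the op itself can be skipped and the
// input is wrapped as a ScalarValue directly.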
ValueRefList remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(!inputs_mask.item());
    bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
    if (is_scalar && remove_axis.axis.size() == 1) {
        return {ScalarValue::make(inputs.item())};
    }
    auto outputs = imperative::apply(remove_axis, inputs);
    if (is_scalar) {
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
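
// Reduce with an explicit target shape: when that shape denotes a scalar,
// substitute a {1}-shaped tensor and wrap the result as a ScalarValue.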
ValueRefList reduce_rule(
        const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if (inputs.size() == 1) {
        return imperative::apply(reduce, inputs);
    }
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    if (is_scalar) {
        CompNode device = *inputs[0].device();
        return {ScalarValue::make(
                imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
    }
    return imperative::apply(reduce, inputs);
}
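
// CollectiveComm: only for the listed modes does a scalar input imply a scalar
// output; all other modes fall through unchanged.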
ValueRefList collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 1);
    static std::unordered_set<CollectiveComm::Mode> modes = {
            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
            CollectiveComm::Mode::ALL_REDUCE_SUM, CollectiveComm::Mode::BROADCAST,
            CollectiveComm::Mode::REDUCE_SUM,
    };
    if (modes.count(collective_comm.mode) == 0) {
        return imperative::apply(collective_comm, inputs);
    }
    if (inputs_mask.item()) {
        return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])};
    } else {
        return imperative::apply(collective_comm, inputs);
    }
}
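
// ParamPackSplit: outputs whose declared shape is empty are scalars; wrap them
// accordingly.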
ValueRefList param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    auto outputs = imperative::apply(param_pack_split, inputs);
    size_t nr_outputs = outputs.size();
    mgb_assert(nr_outputs == param_pack_split.shapes.size());
    for (size_t i = 0; i < nr_outputs; ++i) {
        if (param_pack_split.shapes[i].empty()) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}
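
// Dot always reduces its inputs to a single element, so its output is always a
// scalar.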
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    return {ScalarValue::make(imperative::apply(dot, inputs)[0])};
}
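
// AddAxis on a scalar: the underlying value is already stored as a {1}-shaped
// tensor, so adding axis 0 is satisfied by simply unwrapping it; any remaining
// axes are applied to that tensor.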
ValueRefList add_axis_rule(
        const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 1);
    if (inputs_mask.item()) {
        mgb_assert(add_axis.axis[0] == 0);
        if (add_axis.axis.size() == 1) {
            return {inputs[0]};
        } else {
            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
            return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]);
        }
    } else {
        return imperative::apply(add_axis, inputs);
    }
}
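
// RemoteRecv with an empty shape would receive a scalar; rewrite it to receive
// a {1}-shaped tensor instead.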
ValueRefList remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    if (remote_recv.shape.empty()) {
        std::vector<int32_t> shape = {1};
        auto remote_recv_no_scalar = RemoteRecv::make(
                remote_recv.key, remote_recv.addr, remote_recv.port,
                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
                remote_recv.backend);
        remote_recv_no_scalar->set_scope(remote_recv.scope());
        return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs);
    } else {
        return imperative::apply(remote_recv, inputs);
    }
}
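
// CheckNonFinite returns one output per input plus a trailing flag tensor. The
// flag is always a scalar; each per-input output keeps the scalar-ness of its
// corresponding input.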
ValueRefList check_no_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    auto outputs = imperative::apply(check_no_finite, inputs);
    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
    outputs.back() = ScalarValue::make(outputs.back());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs_mask[i]) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}
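
// Subtensor yields a scalar when every dim of the input is consumed by an
// index item (each `idx` entry removes one dim). Unknown input shapes are
// conservatively treated as non-scalar.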
ValueRefList subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() >= 1);
    auto input = inputs[0];
    bool is_scalar;
    mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
    if (auto shape = input.shape()) {
        size_t ndim = shape->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
            if (idx) {
                ndim--;
            }
        }
        is_scalar = ndim == 0;
    } else {
        // assume not scalar
        is_scalar = false;
    }
    auto outputs = imperative::apply(subtensor, inputs);
    if (is_scalar) {
        outputs[0] = ScalarValue::make(outputs[0]);
    }
    return outputs;
}
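
// The shape of a scalar is the empty shape. When all inputs are scalars, build
// an empty Int32 tensor directly instead of applying GetVarShape.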
ValueRefList get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs,
        Span<bool> inputs_mask) {
    bool all_scalar = true;
    mgb_assert(inputs.size() >= 1);
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    if (all_scalar) {
        auto device = inputs[0].device();
        auto storage = HostStorage::make(*device);
        // storage->ensure_size(1);
        return imperative::apply(
                CreateTensor(
                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
                storage);
    } else {
        return imperative::apply(get_var_shape, inputs);
    }
}
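
// Reshape to a scalar target shape: substitute a {1}-shaped tensor and wrap
// the result as a ScalarValue.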
ValueRefList reshape_rule(
        const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(reshape, inputs);
    }
}
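
// Broadcast to a scalar target shape: handled the same way as reshape_rule.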
ValueRefList broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(broadcast, inputs);
    }
}
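
// For subgraph-like ops (SubgraphOp, CompiledOp) the output shapes are not
// known here, so assume every output is a scalar when every input is a scalar.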
template <typename T>
ValueRefList subgraph_op_rule(
        const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask) {
    // TODO: add flag instead of assume
    bool all_scalar = true;
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op, inputs);
    if (all_scalar) {
        for (auto& output : outputs) {
            output = ScalarValue::make(output);
        }
    }
    return outputs;
}
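
// Register all rules once at static-initialization time.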
struct ScalarRuleRegistry {
    ScalarRuleRegistry() {
        register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>();
        register_scalar_rule<RemoveAxis, remove_axis_rule>();
        register_scalar_rule<Reduce, reduce_rule>();
        register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>();
        register_scalar_rule<CollectiveComm, collective_comm_rule>();
        register_scalar_rule<ParamPackSplit, param_pack_split_rule>();
        register_scalar_rule<Dot, dot_rule>();
        register_scalar_rule<AddAxis, add_axis_rule>();
        register_scalar_rule<RemoteRecv, remote_recv_rule>();
        register_scalar_rule<CheckNonFinite, check_no_finite_rule>();
        register_scalar_rule<Subtensor, subtensor_rule>();
        register_scalar_rule<GetVarShape, get_var_shape_rule>();
        register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>();
        register_scalar_rule<Reshape, reshape_rule>();
        register_scalar_rule<Broadcast, broadcast_rule>();
        register_scalar_rule<Copy, elemwise_rule<Copy, 1>>();
        register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>();
        register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>();
        register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>();
    }
} _;
} // namespace
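
// Fast path for GetAttr on a ScalarValue: Shape returns the cached empty
// shape; Value/Data run on the underlying {1} tensor and then strip the extra
// dim so the returned host/device value is 0-dim.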
ValueRefList ScalarTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto&& input = inputs.item();
    bool is_scalar = input.is<ScalarValue>();
    if (!is_scalar) {
        return imperative::apply(get_attr, input);
    }
    auto unwrapped_input = input.cast<ScalarValue>().value();
    if (get_attr.attr() == GetAttr::Shape) {
        if (!m_empty_shape) {
            m_empty_shape = ShapeValue::make();
        }
        return {m_empty_shape};
    } else {
        auto outputs = imperative::apply(get_attr, unwrapped_input);
        auto& output = outputs[0];
        switch (get_attr.attr()) {
            case GetAttr::Value: {
                auto& hv = output.cast<HostValue>();
                mgb_assert(
                        hv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        hv.shape().to_string().c_str());
                output = HostValue::make(hv.dtype(), ValueShape(), hv.storage());
                break;
            }
            case GetAttr::Data: {
                auto& dv = output.cast<DeviceValue>();
                mgb_assert(
                        dv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        dv.shape().to_string().c_str());
                output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage());
                break;
            }
            default:
                break;
        }
        return outputs;
    }
}
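
// Main entry: unwrap ScalarValue inputs and record which ones were scalars,
// then dispatch ApplyOp to a registered rule (falling back to a plain apply),
// turn a scalar CreateTensor into a {1}-shaped create plus wrap, answer
// IsScalar from the mask, and preserve scalar-ness across IdentityLike ops.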
ValueRefList ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* get_attr = op.as<GetAttr>()) {
        // fastpath for GetAttr
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* apply_op = op.as<ApplyOp>()) {
        if (apply_op->op().same_type<FastpathCopy>()) {
            return inputs[0];
        }
    }
    size_t nr_inputs = inputs.size();
    ValueRefList unwrapped_inputs(nr_inputs);
    SmallVector<bool> inputs_mask(nr_inputs);
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
            unwrapped_inputs[i] = scalar_value->value();
            inputs_mask[i] = true;
        } else {
            unwrapped_inputs[i] = inputs[i];
            inputs_mask[i] = false;
        }
    }
    auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); };
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
        } else {
            // TODO: repeat op
            return fallback();
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.as<IsScalar>()) {
        mgb_assert(nr_inputs == 1);
        return {BoolValue::make(inputs_mask[0])};
    } else if (op.is<Operator::IdentityLike>()) {
        mgb_assert(nr_inputs == 1);
        bool is_scalar = inputs_mask[0];
        auto outputs = fallback();
        if (is_scalar) {
            outputs[0] = ScalarValue::make(outputs[0]);
        }
        return outputs;
    } else {
        return fallback();
    }
}

} // namespace imperative
} // namespace mgb