grad_comm_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Generate bprop for comm ops"""
  16. from mindspore import Tensor
  17. import mindspore.common.dtype as mstype
  18. from mindspore.ops import functional as F
  19. from mindspore.communication import get_rank, get_group_size
  20. from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
  21. from .. import operations as P
  22. from ...common.tensor import RowTensor
  23. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  24. from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
  25. Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
  26. ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
  27. _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
  28. from .grad_base import bprop_getters
  29. from ..operations._inner_ops import Send, Receive
  30. from ..operations import _grad_ops as G
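# How the registry works: each getter below is registered against a primitive
# class via @bprop_getters.register(Op). At graph-build time it receives the
# forward primitive as `self` and returns a `bprop(inputs..., out, dout)`
# closure that maps the output gradient back onto the inputs.
#
# For AllReduce the identities used below follow from the forward definition:
#   SUM:     every rank's output depends on every rank's input with weight 1,
#            so dx = AllReduce-SUM(dout).
#   PROD:    out = prod_i(x_i), so dx_i = AllReduce-SUM(dout * out) / x_i
#            (product rule, rewritten to reuse the forward output).
#   MAX/MIN: the gradient flows only to ranks whose input equals the reduced
#            output, hence the equal/cast/mul masking in the default branch.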
@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
    """Generate bprop for AllReduce: AllReduce the dense gradient, or AllGather the sparse (RowTensor) gradient."""
    all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
    all_gather = AllGather(group=self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_reduce_grad.set_prim_instance_name(instance_name)
    equal = P.Equal()
    cast = P.Cast()
    mul = P.Mul()
    div = P.RealDiv()
    dtype = P.DType()

    if self.op == ReduceOp.PROD:
        def bprop(x, out, dout):
            dy1 = mul(dout, out)
            dy2 = all_reduce_grad(dy1)
            dx = div(dy2, x)
            return (dx,)
    elif self.op == ReduceOp.SUM:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)
    else:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
                z = equal(x, out)
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)
    return bprop

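# Send/Receive form a matched pair for pipeline parallelism: the gradient of a
# Send is a Receive of dout from the peer stage over the backward group, and
# the gradient of a Receive is a Send of dout back to the producer. The zero
# `virtual_input` below only anchors the backward Receive in the graph; the
# real gradient data comes from the peer rank.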
@bprop_getters.register(Send)
def get_bprop_send(self):
    """Generate bprop for Send."""
    shape = self.get_attr_dict()["shape"]
    dtype = self.get_attr_dict()["dtype"]
    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back)
    virtual_input = Tensor(0.0, dtype)

    def bprop(x, out, dout):
        dx = send_grad(virtual_input)
        return (dx,)
    return bprop

@bprop_getters.register(Receive)
def get_bprop_receive(self):
    """Generate bprop for Receive."""
    receive_grad = Send(self.tag, self.rank, self.group_back)
    depend = P.Depend()
    cast = P.Cast()
    out_tensor = Tensor(0.0, mstype.float16)
    is_opt_shard = _get_enable_parallel_optimizer()

    def bprop(x, out, dout):
        send_out = receive_grad(dout)
        if is_opt_shard:
            dx = depend(F.zeros_like(x), send_out)
        else:
            dx = depend(cast(out_tensor, F.dtype(x)), send_out)
        return (dx,)
    return bprop

@bprop_getters.register(_VirtualAdd)
def get_bprop_virtual_add(self):
    """Generate bprop for _VirtualAdd."""
    def bprop(x, grad_accu, out, dout):
        return (dout + grad_accu, zeros_like(grad_accu))
    return bprop

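# _VirtualAssignAdd accumulates dout into the accumulator parameter `y` rather
# than returning a real input gradient (the zero `out_tensor` placeholders are
# cast to the right dtypes). When gradient accumulation is sharded across a
# group, dout is first ReduceScattered so each rank keeps only its own shard.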
@bprop_getters.register(_VirtualAssignAdd)
def get_bprop_virtual_assign_add(self):
    """Generate bprop for _VirtualAssignAdd."""
    assign_add = P.AssignAdd()
    cast = P.Cast()
    dtype = P.DType()
    out_tensor = Tensor(0.0, mstype.float16)
    reduce_scatter = None
    group = self.get_attr_dict().get("group", None)
    fusion = self.get_attr_dict().get("fusion", 0)
    if group:
        reduce_scatter = ReduceScatter(ReduceOp.SUM, group).add_prim_attr("fusion", fusion)
        if self.instance_name:
            instance_name = "_grad_accumulation_shard_grad" + self.instance_name
            reduce_scatter.set_prim_instance_name(instance_name)
        # In pipeline training the fused communication would otherwise be
        # visited later, which can increase memory usage, so tag the op to keep
        # the fusion from taking effect here.
        reduce_scatter.add_prim_attr("not_delay_fusion", True)

    def bprop(x, y, out, dout):
        if reduce_scatter:
            dout = reduce_scatter(dout)
        temp = assign_add(y, dout)
        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
    return bprop

@bprop_getters.register(_VirtualAccuGrad)
def get_bprop_virtual_accu_grad(self):
    """Generate bprop for _VirtualAccuGrad."""
    cast = P.Cast()
    dtype = P.DType()
    out_tensor = Tensor(0.0, mstype.float16)

    def bprop(x, y, out, dout):
        return (F.depend(y, dout), cast(out_tensor, dtype(y)))
    return bprop

@bprop_getters.register(_MirrorMicroStepOperator)
def get_bprop_mirror_micro_step_operator(self):
    """
    Backpropagator for _MirrorMicroStepOperator: AllReduce the gradient across
    the devices in the group, or AllGather it for sparse features.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    scale = 1 / dev_num
    all_reduce = AllReduce(group=group)
    fusion = self.get_attr_dict()["fusion"]
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
    cast = P.Cast()
    dtype = P.DType()
    assign = P.Assign()
    if "parameter_micro" in self.get_attr_dict():
        assign.add_prim_attr("parameter_micro", 0)
    out_tensor = Tensor(1.0, mstype.float16)
    opt_shard = _get_enable_parallel_optimizer()

    def bprop(x, z, out, dout):
        real_grad = z
        assign_out = dout
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                z = F.depend(z, dout)
                real_grad = all_reduce(z)
                real_grad = F.tensor_mul(real_grad, scale)
                assign_out = assign(z, real_grad)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                z = F.depend(z, dout)
                real_grad = all_reduce(z)
                assign_out = assign(z, real_grad)
        if opt_shard:
            return (real_grad, cast(out_tensor, dtype(z)))
        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
    return bprop

@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
    """Generate bprop for Broadcast."""
    def bprop(x, out, dout):
        return (dout,)
    return bprop

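# AllGather and ReduceScatter are dual collectives, so the gradient of
# AllGather is ReduceScatter(SUM): every rank consumed each gathered slice, so
# the incoming gradients for a slice are summed and returned to its owner.
# Hypothetical example with 2 ranks: a (4, 8) input gathered into (8, 8) has
# its (8, 8) dout reduce-scattered back into a summed (4, 8) slice per rank.
# With mean_flag set, the result is additionally scaled by 1 / rank_size
# (gradient averaging across the group).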
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather"""
    fusion = self.get_attr_dict()["fusion"]
    reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    if self.instance_name:
        instance_name = "grad_" + self.instance_name
        reduce_scatter.set_prim_instance_name(instance_name)
    mean_flag = self.get_attr_dict()["mean_flag"]
    scale = 1 / self.rank_size

    def bprop(x, out, dout):
        dx = reduce_scatter(dout)
        if mean_flag:
            dx = F.tensor_mul(dx, scale)
        return (dx,)
    return bprop

@bprop_getters.register(_MiniStepAllGather)
def get_bprop_mini_step_all_gather(self):
    """Generate bprop for _MiniStepAllGather"""
    fusion = self.get_attr_dict()["fusion"]
    mean_flag = self.get_attr_dict()["mean_flag"]
    do_mirror = self.get_attr_dict()["do_mirror"]
    add_accu = self.get_attr_dict().get("add_accu", False)
    gradient_shard = _get_grad_accumulation_shard()
    scale = 1 / self.rank_size
    all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    assign_add = P.AssignAdd()
    if self.instance_name:
        instance_name = "grad_" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
    rank = get_rank(self.group)
    dev_num = get_group_size(self.group)
    split = P.Split(output_num=dev_num)

    def bprop(x, z, out, dout):
        if do_mirror:
            if not gradient_shard:
                z = F.depend(z, F.assign_add(z, dout))
                grad = all_reduce(z)
                dx = split(grad)[rank]
                if mean_flag:
                    dx = F.tensor_mul(dx, scale)
            else:
                dout = F.depend(dout, z)
                grad = all_reduce(dout)
                dx = split(grad)[rank]
                if mean_flag:
                    dx = F.tensor_mul(dx, scale)
                if add_accu:
                    z = assign_add(z, dx)
                    dx = F.depend(dx, z)
        else:
            dx = dout
        return (dx, zeros_like(z))
    return bprop

@bprop_getters.register(_MicroStepAllGather)
def get_bprop_micro_step_all_gather(self):
    """Generate bprop for _MicroStepAllGather"""
    fusion = self.get_attr_dict()["fusion"]
    mean_flag = self.get_attr_dict()["mean_flag"]
    do_mirror = self.get_attr_dict()["do_mirror"]
    scale = 1 / self.rank_size
    all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    rank = get_rank(self.group)
    dev_num = get_group_size(self.group)
    split = P.Split(output_num=dev_num)
    if self.instance_name:
        instance_name = "grad_" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
    cast = P.Cast()
    dtype = P.DType()
    out_tensor = Tensor(1.0, mstype.float16)

    # z: accu_grad
    def bprop(x, z, out, dout):
        z = F.depend(z, dout)
        if not do_mirror:
            return (z, cast(out_tensor, dtype(z)))
        real_grad = all_reduce(z)
        real_grad = split(real_grad)[rank]
        if mean_flag:
            real_grad = F.tensor_mul(real_grad, scale)
        return (real_grad, cast(out_tensor, dtype(z)))
    return bprop

@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
    """Generate bprop for _HostAllGather"""
    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = host_all_gather_grad(dout)
        return (dx,)
    return bprop

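# By the same duality, the gradient of ReduceScatter(SUM) is AllGather: each
# rank's dout slice is gathered back into the full tensor layout that every
# rank contributed to. Other reduce ops have no such linear inverse here,
# hence the explicit RuntimeError below.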
@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
    """Generate bprop for ReduceScatter"""
    reduce_scatter_grad = AllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The ReduceScatter bprop only supports ReduceOp.SUM for now.")

    def bprop(x, out, dout):
        dx = reduce_scatter_grad(dout)
        return (dx,)
    return bprop

@bprop_getters.register(AllSwap)
def get_bprop_allswap(self):
    """Generate bprop for AllSwap."""
    all_swap_grad = AllSwap(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_swap_grad.set_prim_instance_name(instance_name)

    def bprop(x, send_size, recv_size, out, dout):
        dx = all_swap_grad(dout, recv_size, send_size)
        return (dx, zeros_like(send_size), zeros_like(recv_size))
    return bprop

@bprop_getters.register(_HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
    """Generate bprop for _HostReduceScatter"""
    host_reduce_scatter_grad = _HostAllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The _HostReduceScatter bprop only supports ReduceOp.SUM for now.")

    def bprop(x, out, dout):
        dx = host_reduce_scatter_grad(dout)
        return (dx,)
    return bprop

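# NeighborExchange's gradient is another NeighborExchange with the send/recv
# roles flipped: gradients flow backward along each edge, so the backward op
# sends to the ranks it received from and receives from the ranks it sent to,
# with the shape lists swapped to match.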
@bprop_getters.register(NeighborExchange)
def get_bprop_neighborexchange(self):
    """Generate bprop for NeighborExchange."""
    group = self.group
    send_rank_ids = self.recv_rank_ids
    recv_rank_ids = self.send_rank_ids
    recv_shapes = self.send_shapes
    send_shapes = self.recv_shapes
    recv_type = self.recv_type
    neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes,
                                             send_shapes, recv_type, group)

    def bprop(x, out, dout):
        return (neighborexchange_grad(dout),)
    return bprop

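# AlltoAll's gradient is an AlltoAll with the split and concat dimensions
# exchanged: undoing a scatter along split_dim and a gather along concat_dim
# requires scattering along concat_dim and gathering along split_dim.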
@bprop_getters.register(AlltoAll)
def get_bprop_all_to_all(self):
    """Generate bprop for AlltoAll."""
    all_to_all_grad = AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_to_all_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = all_to_all_grad(dout)
        return (dx,)
    return bprop

@bprop_getters.register(NeighborExchangeV2)
def get_bprop_neighborexchangev2(self):
    """Generate bprop for NeighborExchangeV2."""
    group = self.group
    send_rank_ids = self.recv_rank_ids
    recv_rank_ids = self.send_rank_ids
    send_lens = self.recv_lens
    recv_lens = self.send_lens
    data_format = self.data_format
    neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids,
                                                       recv_lens, data_format, group)

    def bprop(x, out, dout):
        return (neighborexchangev2_grad(dout),)
    return bprop

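# _MirrorOperator implements data-parallel gradient mirroring: the parameter
# is replicated across the group in the forward pass, so its gradient is the
# sum (AllReduce) of the per-device gradients, divided by dev_num when
# mean_flag is set. Sparse (RowTensor) gradients are AllGathered instead,
# since each rank holds different indices.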
@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
    """
    Backpropagator for _MirrorOperator: AllReduce the gradient across the
    devices in the group (a single group only), or AllGather it for sparse
    features.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
    fusion = self.get_attr_dict()["fusion"]
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
        return (dx,)
    return bprop

@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
    """
    Backpropagator for _MirrorMiniStepOperator: AllReduce the gradient across
    the devices in the group, or AllGather it for sparse features.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    all_reduce = AllReduce(group=group)
    mul = P.Mul()
    cast = P.Cast()
    fusion = self.get_attr_dict()["fusion"]
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
    do_mirror = self.get_attr_dict()["do_mirror"]

    def bprop(x, z, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    z = F.depend(z, F.assign_add(z, dout))
                    real_grad = all_reduce(z)
                    dx = real_grad
                else:
                    dx = dout
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                dx = zeros_like(x)  # gradient accumulation does not support row tensors yet
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    z = F.depend(z, F.assign_add(z, dout))
                    real_grad = all_reduce(z)
                    dx = real_grad
                else:
                    dx = dout
            else:
                dx = zeros_like(x)  # gradient accumulation does not support row tensors yet
        return (dx, zeros_like(z))
    return bprop

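# _VirtualDiv appears to be used to rescale gradients by a fixed divisor (for
# example, averaging over a parallel degree). Bool/int gradients are passed
# through unchanged since dividing them is meaningless, and tuple/list
# gradients are divided element-wise.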
@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
    """Backpropagator for _VirtualDiv: divide the gradient by the divisor."""
    divisor = self.divisor
    op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
                    or F.issubclass_(F.dtype(dout), mstype.int16):
                return (dout,)
            dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
            return (dx,)
        if F.issubclass_(F.typeof(dout), mstype.tuple_):
            dx = ()
            input_nums = F.tuple_len(dout)
            for i in range(input_nums):
                ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
                dx = dx + (ele_grad,)
            return (dx,)
        dx = []
        input_nums = F.list_len(dout)
        for i in range(input_nums):
            ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
            dx.append(ele_grad)
        return (dx,)
    return bprop

@bprop_getters.register(_GetTensorSlice)
def get_bprop_get_tensor_slice_operator(self):
    """Backpropagator for _GetTensorSlice"""
    def bprop(x, dev_mat, tensor_map, out, dout):
        return (zeros_like(x),)
    return bprop