You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895
  1. # -*- coding: utf-8 -*-
  2. # pylint: disable=too-many-lines
  3. from functools import lru_cache
  4. from typing import NamedTuple, Optional, Sequence, Tuple, Union
  5. from ..core import _config
  6. from ..core._imperative_rt.core2 import (
  7. Const,
  8. adaptive_pool2d_cpp,
  9. apply,
  10. dtype_promotion,
  11. pixel_shuffle_cpp,
  12. )
  13. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  14. from ..core.ops import builtin
  15. from ..core.ops.builtin import (
  16. BatchNorm,
  17. Dimshuffle,
  18. Dropout,
  19. Elemwise,
  20. GetVarShape,
  21. Identity,
  22. Reduce,
  23. Reshape,
  24. TypeCvt,
  25. )
  26. from ..core.tensor import amp, megbrain_graph
  27. from ..core.tensor.array_method import _matmul
  28. from ..core.tensor.utils import (
  29. astensor1d,
  30. cast_tensors,
  31. convert_single_value,
  32. make_shape_tuple,
  33. subgraph,
  34. subgraph_fn,
  35. )
  36. from ..device import get_default_device
  37. from ..distributed import WORLD, is_distributed
  38. from ..jit import exclude_from_trace
  39. from ..tensor import Tensor
  40. from ..utils.deprecation import deprecated_func
  41. from .debug_param import get_execution_strategy
  42. from .distributed import all_reduce_sum
  43. from .elemwise import _elwise, exp, log, log1p, maximum, minimum
  44. from .math import max, normalize, sum
  45. from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
  46. __all__ = [
  47. "adaptive_avg_pool2d",
  48. "adaptive_max_pool2d",
  49. "avg_pool2d",
  50. "batch_norm",
  51. "conv1d",
  52. "conv2d",
  53. "conv3d",
  54. "conv_transpose2d",
  55. "conv_transpose3d",
  56. "deformable_conv2d",
  57. "deformable_psroi_pooling",
  58. "dropout",
  59. "embedding",
  60. "gelu",
  61. "group_norm",
  62. "hsigmoid",
  63. "hswish",
  64. "indexing_one_hot",
  65. "layer_norm",
  66. "leaky_relu",
  67. "linear",
  68. "local_conv2d",
  69. "local_response_norm",
  70. "logsigmoid",
  71. "logsumexp",
  72. "logsoftmax",
  73. "max_pool2d",
  74. "normalize",
  75. "one_hot",
  76. "prelu",
  77. "pad",
  78. "relu",
  79. "relu6",
  80. "remap",
  81. "sigmoid",
  82. "sliding_window",
  83. "sliding_window_transpose",
  84. "silu",
  85. "softmax",
  86. "softplus",
  87. "sync_batch_norm",
  88. "warp_affine",
  89. "warp_perspective",
  90. "pixel_shuffle",
  91. "region_restricted_conv",
  92. ]
  93. def expand_hw(x):
  94. # judge int is 5 times faster than judge Sequence
  95. if isinstance(x, int):
  96. return x, x
  97. if isinstance(x, Sequence):
  98. return int(x[0]), int(x[1])
  99. return int(x), int(x)
  100. def expand_dhw(x):
  101. if isinstance(x, int):
  102. return x, x, x
  103. if isinstance(x, Sequence):
  104. return int(x[0]), int(x[1]), int(x[2])
  105. return int(x), int(x), int(x)
  106. def linear(
  107. inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
  108. ) -> Tensor:
  109. r"""Applies a linear transformation to the input tensor.
  110. Refer to :class:`~.module.linear.Linear` for more information.
  111. Args:
  112. inp: input tensor with shape `(N, in_features)`.
  113. weight: weight with shape `(out_features, in_features)`.
  114. bias: bias with shape `(out_features,)`. Default: None
  115. """
  116. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  117. ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
  118. if bias is not None:
  119. if amp._enabled:
  120. bias = bias.astype("float16")
  121. ret += bias
  122. return ret
  123. def conv1d(
  124. inp: Tensor,
  125. weight: Tensor,
  126. bias: Optional[Tensor] = None,
  127. stride: int = 1,
  128. padding: int = 0,
  129. dilation: int = 1,
  130. groups: int = 1,
  131. conv_mode="cross_correlation",
  132. compute_mode="default",
  133. ) -> Tensor:
  134. r"""1D convolution operation.
  135. Refer to :class:`~.Conv1d` for more information.
  136. Args:
  137. inp: The feature map of the convolution operation
  138. weight: The convolution kernel.
  139. bias: The bias added to the result of convolution (if given)
  140. stride: Stride of the 1D convolution operation. Default: 1
  141. padding: Size of the paddings added to the input on both sides of its
  142. spatial dimensions. Only zero-padding is supported. Default: 0
  143. dilation: Dilation of the 1D convolution operation. Default: 1
  144. groups: number of groups to divide input and output channels into,
  145. so as to perform a "grouped convolution". When ``groups`` is not 1,
  146. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  147. and the shape of weight should be ``(groups, out_channel // groups,
  148. in_channels // groups, kernel_size)``. Default: 1
  149. conv_mode: Supports 'cross_correlation'. Default:
  150. 'cross_correlation'.
  151. compute_mode: When set to 'default', no special requirements will be
  152. placed on the precision of intermediate results. When set to 'float32',
  153. float32 would be used for accumulator and intermediate result, but only
  154. effective when input and output are of float16 dtype.
  155. """
  156. assert (
  157. conv_mode.lower() == "cross_correlation"
  158. or conv_mode.name == "CROSS_CORRELATION"
  159. )
  160. assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
  161. assert inp.ndim == 3, "the input dimension of conv1d should be 3"
  162. assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
  163. if bias is not None:
  164. assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
  165. stride_h = stride
  166. pad_h = padding
  167. dilate_h = dilation
  168. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  169. sparse_type = "dense" if groups == 1 else "group"
  170. op = builtin.Convolution(
  171. stride_h=stride_h,
  172. stride_w=1,
  173. pad_h=pad_h,
  174. pad_w=0,
  175. dilate_h=dilate_h,
  176. dilate_w=1,
  177. strategy=get_execution_strategy(),
  178. mode=conv_mode,
  179. compute_mode=compute_mode,
  180. sparse=sparse_type,
  181. )
  182. (output,) = apply(op, inp, weight)
  183. if bias is not None:
  184. if amp._enabled:
  185. (bias,) = cast_tensors(bias)
  186. output += bias
  187. return output
  188. def conv2d(
  189. inp: Tensor,
  190. weight: Tensor,
  191. bias: Optional[Tensor] = None,
  192. stride: Union[int, Tuple[int, int]] = 1,
  193. padding: Union[int, Tuple[int, int]] = 0,
  194. dilation: Union[int, Tuple[int, int]] = 1,
  195. groups: int = 1,
  196. conv_mode="cross_correlation",
  197. compute_mode="default",
  198. ) -> Tensor:
  199. r"""2D convolution operation.
  200. Refer to :class:`~.module.Conv2d` for more information.
  201. Args:
  202. inp: feature map of the convolution operation.
  203. weight: convolution kernel.
  204. bias: bias added to the result of convolution (if given).
  205. stride: stride of the 2D convolution operation. Default: 1
  206. padding: size of the paddings added to the input on both sides of its
  207. spatial dimensions. Only zero-padding is supported. Default: 0
  208. dilation: dilation of the 2D convolution operation. Default: 1
  209. groups: number of groups into which the input and output channels are divided,
  210. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  211. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  212. and the shape of weight should be ``(groups, out_channel // groups,
  213. in_channels // groups, height, width)``. Default: 1
  214. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  215. compute_mode: when set to "default", no special requirements will be
  216. placed on the precision of intermediate results. When set to "float32",
  217. "float32" would be used for accumulator and intermediate result, but only
  218. effective when input and output are of float16 dtype.
  219. Returns:
  220. output tensor.
  221. """
  222. assert (
  223. conv_mode.lower() == "cross_correlation"
  224. or conv_mode.name == "CROSS_CORRELATION"
  225. )
  226. stride_h, stride_w = expand_hw(stride)
  227. pad_h, pad_w = expand_hw(padding)
  228. dilate_h, dilate_w = expand_hw(dilation)
  229. sparse_type = "dense" if groups == 1 else "group"
  230. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  231. op = builtin.Convolution(
  232. stride_h=stride_h,
  233. stride_w=stride_w,
  234. pad_h=pad_h,
  235. pad_w=pad_w,
  236. dilate_h=dilate_h,
  237. dilate_w=dilate_w,
  238. strategy=get_execution_strategy(),
  239. mode=conv_mode,
  240. compute_mode=compute_mode,
  241. sparse=sparse_type,
  242. )
  243. (output,) = apply(op, inp, weight)
  244. if bias is not None:
  245. if amp._enabled:
  246. (bias,) = cast_tensors(bias)
  247. output += bias
  248. return output
  249. def conv3d(
  250. inp: Tensor,
  251. weight: Tensor,
  252. bias: Optional[Tensor] = None,
  253. stride: Union[int, Tuple[int, int, int]] = 1,
  254. padding: Union[int, Tuple[int, int, int]] = 0,
  255. dilation: Union[int, Tuple[int, int, int]] = 1,
  256. groups: int = 1,
  257. conv_mode: str = "cross_correlation",
  258. ) -> Tensor:
  259. r"""3D convolution operation.
  260. Refer to :class:`~.Conv3d` for more information.
  261. Args:
  262. inp: feature map of the convolution operation.
  263. weight: convolution kernel.
  264. bias: bias added to the result of convolution (if given).
  265. stride: stride of the 3D convolution operation. Default: 1
  266. padding: size of the paddings added to the input on both sides of its
  267. spatial dimensions. Only zero-padding is supported. Default: 0
  268. dilation: dilation of the 3D convolution operation. Default: 1
  269. groups: number of groups into which the input and output channels are divided,
  270. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  271. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  272. and the shape of weight should be ``(groups, out_channel // groups,
  273. in_channels // groups, depth, height, width)``. Default: 1
  274. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  275. Returns:
  276. output tensor.
  277. """
  278. assert conv_mode.lower() == "cross_correlation"
  279. D, H, W = 0, 1, 2
  280. pad = expand_dhw(padding)
  281. stride = expand_dhw(stride)
  282. dilate = expand_dhw(dilation)
  283. sparse_type = "dense" if groups == 1 else "group"
  284. op = builtin.Convolution3D(
  285. pad_d=pad[D],
  286. pad_h=pad[H],
  287. pad_w=pad[W],
  288. stride_d=stride[D],
  289. stride_h=stride[H],
  290. stride_w=stride[W],
  291. dilate_d=dilate[D],
  292. dilate_h=dilate[H],
  293. dilate_w=dilate[W],
  294. strategy=get_execution_strategy(),
  295. mode=conv_mode,
  296. sparse=sparse_type,
  297. )
  298. (output,) = apply(op, inp, weight)
  299. if bias is not None:
  300. output += bias
  301. return output
  302. def conv_transpose2d(
  303. inp: Tensor,
  304. weight: Tensor,
  305. bias: Optional[Tensor] = None,
  306. stride: Union[int, Tuple[int, int]] = 1,
  307. padding: Union[int, Tuple[int, int]] = 0,
  308. output_padding: Union[int, Tuple[int, int]] = 0,
  309. dilation: Union[int, Tuple[int, int]] = 1,
  310. groups: int = 1,
  311. conv_mode="cross_correlation",
  312. compute_mode="default",
  313. ) -> Tensor:
  314. r"""2D transposed convolution operation.
  315. Refer to :class:`~.module.conv.ConvTranspose2d` for more information.
  316. Args:
  317. inp: feature map of the convolution operation.
  318. weight: convolution kernel.
  319. weight usually has shape ``(in_channels, out_channels, height, width)``.
  320. bias: bias added to the result of convolution (if given).
  321. stride: stride of the 2D convolution operation. Default: 1
  322. padding: size of the paddings added to the input on both sides of its
  323. spatial dimensions. Only zero-padding is supported. Default: 0
  324. output_padding: size of paddings appended to output. Default: 0
  325. dilation: dilation of the 2D convolution operation. Default: 1
  326. groups: number of groups into which the input and output channels are divided,
  327. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  328. ``in_channels`` and ``out_channels`` must be divisible by groups,
  329. and the shape of weight should be ``(groups, in_channels // groups,
  330. out_channels // groups, height, width)``. Default: 1
  331. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  332. compute_mode: when set to "default", no special requirements will be
  333. placed on the precision of intermediate results. When set to "float32",
  334. "float32" would be used for accumulator and intermediate result, but only
  335. effective when input and output are of float16 dtype.
  336. Returns:
  337. output tensor.
  338. """
  339. assert (
  340. conv_mode.lower() == "cross_correlation"
  341. or conv_mode.name == "CROSS_CORRELATION"
  342. )
  343. stride_h, stride_w = expand_hw(stride)
  344. pad_h, pad_w = expand_hw(padding)
  345. output_pad_h, output_pad_w = expand_hw(output_padding)
  346. dilate_h, dilate_w = expand_hw(dilation)
  347. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  348. sparse_type = "dense" if groups == 1 else "group"
  349. op = builtin.ConvolutionBackwardData(
  350. stride_h=stride_h,
  351. stride_w=stride_w,
  352. pad_h=pad_h,
  353. pad_w=pad_w,
  354. dilate_h=dilate_h,
  355. dilate_w=dilate_w,
  356. strategy=get_execution_strategy(),
  357. compute_mode=compute_mode,
  358. sparse=sparse_type,
  359. )
  360. if output_pad_h != 0 or output_pad_h != 0:
  361. assert (
  362. output_pad_h < stride[0]
  363. ), "output_padding[0] shoule be less than stride[0]"
  364. assert (
  365. output_pad_w < stride[1]
  366. ), "output_padding[1] shoule be less than stride[1]"
  367. Hout = (
  368. (inp.shape[2] - 1) * stride[0]
  369. - 2 * padding[0]
  370. + dilation[0] * (weight.shape[2] - 1)
  371. + output_pad_h
  372. + 1
  373. )
  374. Wout = (
  375. (inp.shape[3] - 1) * stride[1]
  376. - 2 * padding[1]
  377. + dilation[1] * (weight.shape[3] - 1)
  378. + output_pad_w
  379. + 1
  380. )
  381. output_shape = [inp.shape[0], weight.shape[1], Hout, Wout]
  382. output_shape = astensor1d(output_shape)
  383. (output,) = apply(op, weight, inp, output_shape)
  384. else:
  385. (output,) = apply(op, weight, inp)
  386. if bias is not None:
  387. if amp._enabled:
  388. bias = cast_tensors(bias)
  389. output += bias
  390. return output
  391. def deformable_conv2d(
  392. inp: Tensor,
  393. weight: Tensor,
  394. offset: Tensor,
  395. mask: Tensor,
  396. bias: Optional[Tensor] = None,
  397. stride: Union[int, Tuple[int, int]] = 1,
  398. padding: Union[int, Tuple[int, int]] = 0,
  399. dilation: Union[int, Tuple[int, int]] = 1,
  400. groups: int = 1,
  401. conv_mode="cross_correlation",
  402. compute_mode="default",
  403. ) -> Tensor:
  404. r"""Deformable Convolution.
  405. Args:
  406. inp: input feature map.
  407. weight: convolution kernel.
  408. weight usually has shape ``(out_channels, in_channels, height, width)``.
  409. offset: input offset to kernel, channel of this tensor should match the deformable settings.
  410. mask: input mask to kernel, channel of this tensor should match the deformable settings.
  411. bias: bias added to the result of convolution (if given).
  412. stride: stride of the 2D convolution operation. Default: 1
  413. padding: size of the paddings added to the input on both sides of its
  414. spatial dimensions. Only zero-padding is supported. Default: 0
  415. dilation: dilation of the 2D convolution operation. Default: 1
  416. groups: number of groups into which the input and output channels are divided,
  417. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  418. ``in_channels`` and ``out_channels`` must be divisible by groups,
  419. and the shape of weight should be ``(groups, out_channel // groups,
  420. in_channels // groups, height, width)``. Default: 1
  421. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  422. compute_mode: when set to "default", no special requirements will be
  423. placed on the precision of intermediate results. When set to "float32",
  424. "float32" would be used for accumulator and intermediate result, but only
  425. effective when input and output are of float16 dtype.
  426. Returns:
  427. output tensor.
  428. """
  429. assert (
  430. conv_mode.lower() == "cross_correlation"
  431. or conv_mode.name == "CROSS_CORRELATION"
  432. )
  433. if amp._enabled:
  434. inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
  435. else:
  436. offset = offset.astype("float32")
  437. mask = mask.astype("float32")
  438. stride_h, stride_w = expand_hw(stride)
  439. pad_h, pad_w = expand_hw(padding)
  440. dilate_h, dilate_w = expand_hw(dilation)
  441. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  442. sparse_type = "dense" if groups == 1 else "group"
  443. op = builtin.DeformableConv(
  444. stride_h=stride_h,
  445. stride_w=stride_w,
  446. pad_h=pad_h,
  447. pad_w=pad_w,
  448. dilate_h=dilate_h,
  449. dilate_w=dilate_w,
  450. strategy=get_execution_strategy(),
  451. mode=conv_mode,
  452. compute_mode=compute_mode,
  453. sparse=sparse_type,
  454. )
  455. (output,) = apply(op, inp, weight, offset, mask)
  456. if bias is not None:
  457. output += bias
  458. return output
  459. def local_conv2d(
  460. inp: Tensor,
  461. weight: Tensor,
  462. bias: Optional[Tensor] = None,
  463. stride: Union[int, Tuple[int, int]] = 1,
  464. padding: Union[int, Tuple[int, int]] = 0,
  465. dilation: Union[int, Tuple[int, int]] = 1,
  466. conv_mode="cross_correlation",
  467. ):
  468. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  469. It is also known as the locally connected layer.
  470. Args:
  471. inp: input feature map.
  472. weight: convolution kernel.
  473. weight usually has shape ``(out_channels, in_channels, height, width)``.
  474. bias: bias added to the result of convolution (if given).
  475. stride: stride of the 2D convolution operation. Default: 1
  476. padding: size of the paddings added to the input on both sides of its
  477. spatial dimensions. Only zero-padding is supported. Default: 0
  478. dilation: dilation of the 2D convolution operation. Default: 1
  479. Returns:
  480. output tensor.
  481. """
  482. assert (
  483. conv_mode.lower() == "cross_correlation"
  484. or conv_mode.name == "CROSS_CORRELATION"
  485. )
  486. stride_h, stride_w = expand_hw(stride)
  487. pad_h, pad_w = expand_hw(padding)
  488. dilate_h, dilate_w = expand_hw(dilation)
  489. # local conv only support "dense" mode, but weight could contain group dimension.
  490. op = builtin.GroupLocal(
  491. stride_h=stride_h,
  492. stride_w=stride_w,
  493. pad_h=pad_h,
  494. pad_w=pad_w,
  495. dilate_h=dilate_h,
  496. dilate_w=dilate_w,
  497. mode=conv_mode,
  498. sparse="dense",
  499. )
  500. (output,) = apply(op, inp, weight)
  501. if bias is not None:
  502. output += bias
  503. return output
  504. def conv_transpose3d(
  505. inp: Tensor,
  506. weight: Tensor,
  507. bias: Optional[Tensor] = None,
  508. stride: Union[int, Tuple[int, int, int]] = 1,
  509. padding: Union[int, Tuple[int, int, int]] = 0,
  510. output_padding: Union[int, Tuple[int, int, int]] = 0,
  511. dilation: Union[int, Tuple[int, int, int]] = 1,
  512. groups: int = 1,
  513. ) -> Tensor:
  514. r"""3D transposed convolution operation. Only support the case that groups = 1
  515. and conv_mode = "cross_correlation".
  516. Refer to :class:`~.ConvTranspose3d` for more information.
  517. Args:
  518. inp: feature map of the convolution operation.
  519. weight: convolution kernel.
  520. weight usually has shape ``(in_channels, out_channels, depth, height, width)``.
  521. bias: bias added to the result of convolution (if given).
  522. stride: stride of the 3D convolution operation. Default: 1
  523. padding: size of the paddings added to the input on all sides of its
  524. spatial dimensions. Only zero-padding is supported. Default: 0
  525. output_padding: size of paddings appended to output. Default: 0
  526. dilation: dilation of the 3D convolution operation. Default: 1
  527. groups: number of groups into which the input and output channels are divided,
  528. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  529. ``in_channels`` and ``out_channels`` must be divisible by groups,
  530. and the shape of weight should be ``(groups, in_channels // groups,
  531. out_channels // groups, depth, height, width)``. Default: 1
  532. Returns:
  533. output tensor.
  534. """
  535. D, H, W = 0, 1, 2
  536. pad = expand_dhw(padding)
  537. stride = expand_dhw(stride)
  538. dilate = expand_dhw(dilation)
  539. output_padding = expand_dhw(output_padding)
  540. sparse_type = "dense" if groups == 1 else "group"
  541. op = builtin.Convolution3DBackwardData(
  542. pad_d=pad[D],
  543. pad_h=pad[H],
  544. pad_w=pad[W],
  545. stride_d=stride[D],
  546. stride_h=stride[H],
  547. stride_w=stride[W],
  548. dilate_d=dilate[D],
  549. dilate_h=dilate[H],
  550. dilate_w=dilate[W],
  551. strategy=get_execution_strategy(),
  552. sparse=sparse_type,
  553. )
  554. if output_padding[0] != 0 or output_padding[1] != 0 or output_padding[2] != 0:
  555. assert (
  556. output_padding[0] < stride[0]
  557. ), "output_padding[0] shoule be less than stride[0]"
  558. assert (
  559. output_padding[1] < stride[1]
  560. ), "output_padding[1] shoule be less than stride[1]"
  561. assert (
  562. output_padding[2] < stride[2]
  563. ), "output_padding[2] shoule be less than stride[2]"
  564. Dout = (
  565. (inp.shape[2] - 1) * stride[0]
  566. - 2 * padding[0]
  567. + dilation[0] * (weight.shape[2] - 1)
  568. + output_padding[0]
  569. + 1
  570. )
  571. Hout = (
  572. (inp.shape[3] - 1) * stride[1]
  573. - 2 * padding[1]
  574. + dilation[1] * (weight.shape[3] - 1)
  575. + output_padding[1]
  576. + 1
  577. )
  578. Wout = (
  579. (inp.shape[4] - 1) * stride[2]
  580. - 2 * padding[2]
  581. + dilation[2] * (weight.shape[4] - 1)
  582. + output_padding[2]
  583. + 1
  584. )
  585. output_shape = [inp.shape[0], weight.shape[1], Dout, Hout, Wout]
  586. output_shape = astensor1d(output_shape)
  587. (output,) = apply(op, weight, inp, output_shape)
  588. else:
  589. (output,) = apply(op, weight, inp)
  590. if bias is not None:
  591. output += bias
  592. return output
  593. def max_pool2d(
  594. inp: Tensor,
  595. kernel_size: Union[int, Tuple[int, int]],
  596. stride: Optional[Union[int, Tuple[int, int]]] = None,
  597. padding: Union[int, Tuple[int, int]] = 0,
  598. ) -> Tensor:
  599. r"""Applies a 2D max pooling over an input tensor.
  600. Refer to :class:`~.MaxPool2d` for more information.
  601. Args:
  602. inp: input tensor of shape :math:`(N, C, H_{\text{in}}, W_{\text{in}})`.
  603. kernel_size: size of the window used to calculate the max value.
  604. stride: stride of the window. Default value is ``kernel_size``.
  605. padding: implicit zero padding added on both sides. Default: 0.
  606. Returns:
  607. output tensor of shape `(N, C, H_{\text{out}}, W_{\text{out}})`.
  608. Examples:
  609. >>> import numpy as np
  610. >>> input = Tensor(np.arange(1 * 1 * 3 * 4).astype(np.float32).reshape(1, 1, 3, 4))
  611. >>> F.nn.max_pool2d(input, 2, 1, 0)
  612. Tensor([[[[ 5. 6. 7.]
  613. [ 9. 10. 11.]]]], device=xpux:0)
  614. """
  615. if stride is None:
  616. stride = kernel_size
  617. window_h, window_w = expand_hw(kernel_size)
  618. stride_h, stride_w = expand_hw(stride)
  619. padding_h, padding_w = expand_hw(padding)
  620. op = builtin.Pooling(
  621. window_h=window_h,
  622. window_w=window_w,
  623. stride_h=stride_h,
  624. stride_w=stride_w,
  625. pad_h=padding_h,
  626. pad_w=padding_w,
  627. mode="max",
  628. strategy=get_execution_strategy(),
  629. )
  630. (output,) = apply(op, inp)
  631. return output
  632. def avg_pool2d(
  633. inp: Tensor,
  634. kernel_size: Union[int, Tuple[int, int]],
  635. stride: Optional[Union[int, Tuple[int, int]]] = None,
  636. padding: Union[int, Tuple[int, int]] = 0,
  637. mode: str = "average_count_exclude_padding",
  638. ) -> Tensor:
  639. r"""Applies 2D average pooling over an input tensor.
  640. Refer to :class:`~.AvgPool2d` for more information.
  641. Args:
  642. inp: input tensor of shape :math:`(N, C, H_{\text{in}}, W_{\text{in}})` .
  643. kernel_size: size of the window used to calculate the average value.
  644. stride: stride of the window. Default value is ``kernel_size``.
  645. padding: implicit zero padding added on both sides. Default: 0.
  646. mode: whether to include the padding values while calculating the average, set
  647. to "average" will do counting.
  648. Default: "average_count_exclude_padding"
  649. Returns:
  650. output tensor of shape :math:`(N, C, H_{\text{out}}, W_{\text{out}})`.
  651. Examples:
  652. >>> import numpy as np
  653. >>> inp = Tensor(np.arange(1 * 1 * 3 * 4).astype(np.float32).reshape(1, 1, 3, 4))
  654. >>> F.avg_pool2d(inp, kernel_size=2, stride=2, padding=[1,0], mode="average")
  655. Tensor([[[[0.25 1.25]
  656. [6.5 8.5 ]]]], device=xpux:0)
  657. """
  658. if stride is None:
  659. stride = kernel_size
  660. window_h, window_w = expand_hw(kernel_size)
  661. stride_h, stride_w = expand_hw(stride)
  662. padding_h, padding_w = expand_hw(padding)
  663. op = builtin.Pooling(
  664. window_h=window_h,
  665. window_w=window_w,
  666. stride_h=stride_h,
  667. stride_w=stride_w,
  668. pad_h=padding_h,
  669. pad_w=padding_w,
  670. mode=mode,
  671. strategy=get_execution_strategy(),
  672. )
  673. (output,) = apply(op, inp)
  674. return output
  675. def adaptive_max_pool2d(
  676. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  677. ) -> Tensor:
  678. r"""Applies a 2D max adaptive pooling over an input.
  679. Refer to :class:`~.MaxAdaptivePool2d` for more information.
  680. Args:
  681. inp: input tensor.
  682. oshp: `(OH, OW)` size of the output shape.
  683. Returns:
  684. output tensor.
  685. """
  686. return adaptive_pool2d_cpp(inp, oshp, "MAX")
  687. def adaptive_avg_pool2d(
  688. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  689. ) -> Tensor:
  690. r"""Applies a 2D average adaptive pooling over an input.
  691. Refer to :class:`~.AvgAdaptivePool2d` for more information.
  692. Args:
  693. inp: input tensor.
  694. oshp: `(OH, OW)` size of the output shape.
  695. Returns:
  696. output tensor.
  697. """
  698. return adaptive_pool2d_cpp(inp, oshp, "AVERAGE")
  699. def deformable_psroi_pooling(
  700. inp: Tensor,
  701. rois: Tensor,
  702. trans: Tensor,
  703. no_trans: bool,
  704. part_size: int,
  705. pooled_h: int,
  706. pooled_w: int,
  707. sample_per_part: int,
  708. spatial_scale: float,
  709. trans_std: float = 0.1,
  710. ):
  711. r"""Deformable PSROI(Position Sensitive Region of Interest) Pooling.
  712. Args:
  713. inp: input feature map.
  714. rois: the rois for feature pooling.
  715. trans: input offset to psroi_pooling.
  716. no_trans: check the phase of DeformablePSROIPooling. False to the
  717. 1st phase, True to the 2nd phase.
  718. part_size: part size.
  719. sample_per_part: sample points of each part.
  720. pooled_shape: kernel shape of convolution.
  721. spatial_scale: the spatial_scale w.r.t input image.
  722. trans_std: multiplier used in 2nd phase.
  723. """
  724. op = builtin.DeformablePSROIPooling(
  725. no_trans=no_trans,
  726. part_size=part_size,
  727. pooled_h=pooled_h,
  728. pooled_w=pooled_w,
  729. sample_per_part=sample_per_part,
  730. spatial_scale=spatial_scale,
  731. trans_std=trans_std,
  732. )
  733. output, _ = apply(op, inp, rois, trans)
  734. return output
  735. def hswish(x):
  736. r"""Element-wise `x * relu6(x + 3) / 6`.
  737. Example:
  738. >>> import numpy as np
  739. >>> x = Tensor(np.arange(5).astype(np.float32))
  740. >>> out = F.hswish(x)
  741. >>> out.numpy().round(decimals=4)
  742. array([0. , 0.6667, 1.6667, 3. , 4. ], dtype=float32)
  743. """
  744. return _elwise(x, mode=Elemwise.Mode.H_SWISH)
  745. def sigmoid(x):
  746. r"""Element-wise `1 / ( 1 + exp( -x ) )`."""
  747. return _elwise(x, mode=Elemwise.Mode.SIGMOID)
  748. def hsigmoid(x):
  749. r"""Element-wise `relu6(x + 3) / 6`."""
  750. return _elwise(x, mode=Elemwise.Mode.HSIGMOID)
  751. def relu(x):
  752. r"""Element-wise `max(x, 0)`."""
  753. return _elwise(x, mode=Elemwise.Mode.RELU)
  754. def relu6(x):
  755. r"""Element-wise `min(max(x, 0), 6)`."""
  756. return _elwise(x, mode=Elemwise.Mode.RELU6)
  757. def prelu(x, y):
  758. r"""Element-wise `max(x, 0) + y * min(x, 0)`."""
  759. return _elwise(x, y, mode=Elemwise.Mode.PRELU)
  760. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  761. r"""Element-wise LeakyReLU function
  762. Refer to :class:`~.LeakyReLU` for more information.
  763. """
  764. return _elwise(inp, negative_slope, mode=Elemwise.Mode.PRELU)
  765. def silu(x):
  766. r"""Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`."""
  767. return _elwise(x, mode=Elemwise.Mode.SILU)
  768. def gelu(x):
  769. r"""Applies the element-wise function:
  770. .. math::
  771. \text{gelu}(x) = x\Phi(x)
  772. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  773. """
  774. return _elwise(x, mode=Elemwise.Mode.GELU)
  775. def softplus(inp: Tensor) -> Tensor:
  776. r"""Applies the element-wise function:
  777. .. math::
  778. \text{softplus}(x) = \log(1 + \exp(x))
  779. softplus is a smooth approximation to the ReLU function and can be used
  780. to constrain the output to be always positive.
  781. For numerical stability the implementation follows this transformation:
  782. .. math::
  783. \text{softplus}(x) = \log(1 + \exp(x))
  784. = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
  785. = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
  786. Examples:
  787. >>> import numpy as np
  788. >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
  789. >>> y = F.softplus(x)
  790. >>> y.numpy().round(decimals=4)
  791. array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32)
  792. """
  793. return _elwise(inp, mode=Elemwise.Mode.SOFTPLUS)
  794. def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  795. r"""Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
  796. input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:
  797. .. math::
  798. \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )
  799. For numerical stability the implementation follows this transformation:
  800. .. math::
  801. \text{logsoftmax}(x)
  802. = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
  803. = x - \log (\sum_{i}(\exp (x_{i})))
  804. = x - \text{logsumexp}(x)
  805. Examples:
  806. >>> import numpy as np
  807. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  808. >>> y = F.logsoftmax(x, axis=1)
  809. >>> y.numpy().round(decimals=4)
  810. array([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
  811. [-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]], dtype=float32)
  812. """
  813. return inp - logsumexp(inp, axis, keepdims=True)
  814. def logsigmoid(inp: Tensor) -> Tensor:
  815. r"""Applies the element-wise function:
  816. .. math::
  817. \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
  818. = \log(1/(1 + \exp(-x)))
  819. = - \log(1 + \exp(-x))
  820. = - \text{softplus}(-x)
  821. Examples:
  822. >>> import numpy as np
  823. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32))
  824. >>> y = F.logsigmoid(x)
  825. >>> y.numpy().round(decimals=4)
  826. array([-5.0067, -4.0182, -3.0486, -2.1269, -1.3133, -0.6931, -0.3133,
  827. -0.1269, -0.0486, -0.0181], dtype=float32)
  828. """
  829. return _elwise(inp, mode=Elemwise.Mode.LOGSIGMOID)
  830. def logsumexp(
  831. inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
  832. ) -> Tensor:
  833. r"""Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
  834. .. math::
  835. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  836. For numerical stability, the implementation follows this transformation:
  837. .. math::
  838. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  839. = \text{logsumexp}(x)=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
  840. where
  841. .. math::
  842. b = \max(x_j)
  843. Examples:
  844. >>> import numpy as np
  845. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  846. >>> y = F.logsumexp(x, axis=1, keepdims=False)
  847. >>> y.numpy().round(decimals=4)
  848. array([-0.5481, 4.4519], dtype=float32)
  849. """
  850. max_value = max(inp.detach(), axis, keepdims=True)
  851. if keepdims:
  852. return max_value + log(sum(exp(inp - max_value), axis, keepdims))
  853. else:
  854. return squeeze(max_value, axis=None) + log(
  855. sum(exp(inp - max_value), axis, keepdims)
  856. )
  857. def _get_softmax_axis(ndim: int) -> int:
  858. if ndim in (0, 1, 3):
  859. return 0
  860. return 1
  861. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  862. r"""Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:
  863. .. math::
  864. \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  865. It is applied to all elements along axis, and rescales elements so that
  866. they stay in the range `[0, 1]` and sum to 1.
  867. See :class:`~.module.Softmax` for more details.
  868. Examples:
  869. >>> import numpy as np
  870. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  871. >>> out = F.softmax(x)
  872. >>> out.numpy().round(decimals=4)
  873. array([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
  874. [0.0117, 0.0317, 0.0861, 0.2341, 0.6364]], dtype=float32)
  875. """
  876. if axis is None:
  877. axis = _get_softmax_axis(len(inp.shape))
  878. if isinstance(axis, list):
  879. offset = inp.max(axis=axis, keepdims=True).detach()
  880. cached = exp(inp - offset)
  881. down = sum(cached, axis=axis, keepdims=True)
  882. return cached / down
  883. else:
  884. op = builtin.Softmax(axis=axis,)
  885. (output,) = apply(op, inp)
  886. return output
  887. def group_norm(
  888. inp: Tensor,
  889. num_groups: int,
  890. affine: bool,
  891. weight: Optional[Tensor] = None,
  892. bias: Optional[Tensor] = None,
  893. eps: float = 1e-5,
  894. ):
  895. r"""Applies Group Normalization over a mini-batch of inputs as described in
  896. the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
  897. Args:
  898. inp: input tensor.
  899. num_groups: number of groups to separate the channels into
  900. affine: whether to use weight and bias
  901. weight: must not be None when the affine is true
  902. bias: must not be None when the affine is true
  903. eps: a value added to the denominator for numerical stability. Default: 1e-5
  904. """
  905. op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,)
  906. if affine:
  907. assert weight is not None and bias is not None
  908. return apply(op, inp, weight, bias)[0]
  909. else:
  910. return apply(op, inp)[0]
  911. def layer_norm(
  912. inp: Tensor,
  913. normalized_shape: tuple,
  914. affine: bool,
  915. weight: Optional[Tensor] = None,
  916. bias: Optional[Tensor] = None,
  917. eps: float = 1e-5,
  918. ):
  919. r"""Applies layer normalization to the input. Support tensor of any shape as input.
  920. Reference: https://arxiv.org/pdf/1803.08494.pdf.
  921. Args:
  922. inp: input tensor.
  923. normalized_shape: the shape that you want to be normalizated
  924. affine: whether to use weight and bias
  925. weight: must not be None when the affine is true
  926. bias: must not be None when the affine is true
  927. eps: a value added to the denominator for numerical stability. Default: 1e-5
  928. """
  929. if isinstance(normalized_shape, int):
  930. normalized_shape = [normalized_shape]
  931. normalized_dim = len(normalized_shape)
  932. assert normalized_dim > 0
  933. normalized_size = 1
  934. for i in range(normalized_dim):
  935. normalized_size = normalized_size * normalized_shape[i]
  936. op = builtin.LayerNorm(
  937. affine=affine,
  938. eps=eps,
  939. normalized_dim=normalized_dim,
  940. normalized_size=normalized_size,
  941. )
  942. if affine:
  943. assert weight is not None and bias is not None
  944. return apply(op, inp, weight, bias)[0]
  945. else:
  946. # assert weight is None and bias is None
  947. return apply(op, inp)[0]
  948. def batch_norm(
  949. inp: Tensor,
  950. running_mean: Tensor = None,
  951. running_var: Tensor = None,
  952. weight: Optional[Tensor] = None,
  953. bias: Optional[Tensor] = None,
  954. *,
  955. training: bool = False,
  956. momentum: float = 0.9,
  957. eps: float = 1e-5,
  958. inplace: bool = True,
  959. ):
  960. r"""Applies batch normalization to the input.
  961. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  962. Args:
  963. inp: input tensor.
  964. running_mean: tensor to store running mean.
  965. running_var: tensor to store running variance.
  966. weight: scaling tensor in the learnable affine parameters.
  967. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  968. bias: bias tensor in the learnable affine parameters.
  969. See :math:`\beta` in :class:`~.BatchNorm2d`.
  970. training: a boolean value to indicate whether batch norm is performed
  971. in training mode. Default: False
  972. momentum: value used for the ``running_mean`` and ``running_var``
  973. computation. Default: 0.9
  974. eps: a value added to the denominator for numerical stability. Default: 1e-5
  975. inplace: whether to update ``running_mean`` and ``running_var``
  976. inplace or return new tensors. Default: True
  977. compute_mode: When set to 'default', no special requirements will be
  978. placed on the precision of intermediate results. When set to 'float32',
  979. float32 would be used for accumulator and intermediate result, but only
  980. effective when input and output are of float16 dtype.
  981. param_dim: a value indicating in which format the parameters are.
  982. Default: 'dim_1c11', which means NCHW format.
  983. And 'dim_111c' means NHWC format.
  984. """
  985. def make_full_if_none(x, value):
  986. x_ndim = None if x is None else x.ndim
  987. # in general case, x will be returned here directly
  988. if x_ndim is not None and x_ndim != 1:
  989. return x
  990. C = inp.shape[1]
  991. pshape = (1, C, 1, 1)
  992. if x is None:
  993. x = Const(value, inp.dtype, inp.device)
  994. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  995. (result,) = apply(builtin.Broadcast(), x, shape)
  996. result.format = inp.format
  997. return result
  998. else:
  999. assert x_ndim == 1
  1000. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  1001. (result,) = apply(builtin.Reshape(), x, shape)
  1002. return result
  1003. has_mean = running_mean is not None
  1004. has_var = running_var is not None
  1005. if not training:
  1006. assert has_mean, "running_mean must be provided in inference mode"
  1007. assert has_var, "running_var must be provided in inference mode"
  1008. weight = make_full_if_none(weight, 1)
  1009. bias = make_full_if_none(bias, 0)
  1010. if not training:
  1011. op = builtin.BatchNorm(
  1012. fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
  1013. )
  1014. ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
  1015. return ret
  1016. else:
  1017. op = builtin.BatchNorm(
  1018. avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
  1019. )
  1020. if has_mean or has_var:
  1021. running_mean = make_full_if_none(running_mean, 0)
  1022. running_var = make_full_if_none(running_var, 1)
  1023. new_mean, new_var, *_, inp = apply(
  1024. op, inp, weight, bias, running_mean, running_var
  1025. )
  1026. if not has_mean:
  1027. new_mean = None
  1028. if not has_var:
  1029. new_var = None
  1030. if inplace:
  1031. if has_mean:
  1032. running_mean[...] = new_mean
  1033. if has_var:
  1034. running_var[...] = new_var
  1035. return inp
  1036. else:
  1037. return inp, new_mean, new_var
  1038. else:
  1039. inp = apply(op, inp, weight, bias)[-1]
  1040. return inp
  1041. @lru_cache(maxsize=None)
  1042. def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
  1043. # fmt: off
  1044. @subgraph("SyncBnStage0", dtype, device, 1)
  1045. def syncbn_stage0(inputs, f, c):
  1046. input = inputs[0]
  1047. reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
  1048. input_shape = f(GetVarShape(), input)
  1049. input_elems = f(Reduce(mode="product", axis=0), input_shape)
  1050. reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
  1051. reduce_size = f("//", input_elems, reduce_elems)
  1052. channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
  1053. channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
  1054. reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
  1055. return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)
  1056. @subgraph("SyncBnStage1", dtype, device, 7)
  1057. def syncbn_stage1(inputs, f, c):
  1058. input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
  1059. weight, bias = inputs[5:7]
  1060. channel_mean = f("/", channel_x1s, reduce_size)
  1061. channel_var =\
  1062. f("+", f("/", f("**", channel_x1s, c(2)),
  1063. f("-", f("*", reduce_size, reduce_size))),
  1064. f("/", channel_x2s, reduce_size))
  1065. invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
  1066. inv_var_wt = f("*", invsqrt_channel_var, weight)
  1067. neg_channel_mean = f("-", channel_mean)
  1068. outvar =\
  1069. f("fma3", input, inv_var_wt,
  1070. f("+", f("*", neg_channel_mean, inv_var_wt),
  1071. bias))
  1072. return (outvar, channel_mean, channel_var), (True, True, True)
  1073. @subgraph("SyncBnStage1Inference", dtype, device, 6)
  1074. def syncbn_stage1_inference(inputs, f, c):
  1075. input, channel_mean, channel_var, eps = inputs[0:4]
  1076. weight, bias = inputs[4:6]
  1077. invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
  1078. inv_var_wt = f("*", invsqrt_channel_var, weight)
  1079. neg_channel_mean = f("-", channel_mean)
  1080. outvar =\
  1081. f("+", f("*", input, inv_var_wt),
  1082. f("+", f("*", neg_channel_mean, inv_var_wt),
  1083. bias))
  1084. return (outvar,), (True,)
  1085. @subgraph("SyncBnStage2", dtype, device, 7)
  1086. def syncbn_stage2(inputs, f, c):
  1087. running_mean, running_var, momentum = inputs[0:3]
  1088. reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
  1089. c1_minus_momentum = f("-", c(1), momentum)
  1090. reduce_size_minus_c1 = f("-", reduce_size, c(1))
  1091. running_mean = f("fma4",
  1092. running_mean, momentum,
  1093. c1_minus_momentum, channel_mean,
  1094. )
  1095. channel_variance_unbiased =\
  1096. f("+", f("/", f("**", channel_x1s, c(2)),
  1097. f("*", f("-", reduce_size),
  1098. reduce_size_minus_c1)),
  1099. f("/", channel_x2s,
  1100. reduce_size_minus_c1))
  1101. running_var = f("fma4",
  1102. running_var, momentum,
  1103. c1_minus_momentum, channel_variance_unbiased
  1104. )
  1105. return (running_mean, running_var), (True, True)
  1106. @subgraph("SyncBnConcatStats", dtype, device, 3)
  1107. def syncbn_concat_stats(inputs, f, c):
  1108. reduce_size, channel_x1s, channel_x2s = inputs[0:3]
  1109. reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
  1110. stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
  1111. return (stats,), (True,)
  1112. @subgraph("SyncBnSplitStats", dtype, device, 1)
  1113. def syncbn_split_stats(inputs, f, c):
  1114. stats = inputs[0]
  1115. c_1 = c(1, dtype="int32")
  1116. channel_x1s_end = c(channels+1, dtype="int32")
  1117. def _subtensor(src, axis, begin, end):
  1118. items = (axis, (begin is not None), (end is not None), False, False),
  1119. args = ()
  1120. if begin is not None:
  1121. args += begin,
  1122. if end is not None:
  1123. args += end,
  1124. return f(builtin.Subtensor(items=items), src, *args)
  1125. reduce_size = _subtensor(stats, 1, None, c_1)
  1126. channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
  1127. channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
  1128. reduce_size = f(builtin.Reshape(), reduce_size, c_1)
  1129. return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
  1130. # fmt: on
  1131. return (
  1132. syncbn_stage0,
  1133. syncbn_stage1,
  1134. syncbn_stage1_inference,
  1135. syncbn_stage2,
  1136. syncbn_concat_stats,
  1137. syncbn_split_stats,
  1138. )
  1139. def sync_batch_norm(
  1140. inp: Tensor,
  1141. running_mean: Tensor,
  1142. running_var: Tensor,
  1143. weight: Optional[Tensor] = None,
  1144. bias: Optional[Tensor] = None,
  1145. training: bool = False,
  1146. momentum: Union[float, Tensor] = 0.9,
  1147. eps: float = 1e-5,
  1148. eps_mode="additive",
  1149. group=WORLD,
  1150. ) -> Tensor:
  1151. r"""Applies synchronized batch normalization to the input.
  1152. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  1153. Args:
  1154. inp: input tensor.
  1155. running_mean: tensor to store running mean.
  1156. running_var: tensor to store running variance.
  1157. weight: scaling tensor in the learnable affine parameters.
  1158. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  1159. bias: bias tensor in the learnable affine parameters.
  1160. See :math:`\beta` in :class:`~.BatchNorm2d`.
  1161. training: a boolean value to indicate whether batch norm is performed
  1162. in traning mode. Default: False
  1163. momentum: value used for the ``running_mean`` and ``running_var``
  1164. computation. Default: 0.9
  1165. eps: a value added to the denominator for numerical stability.
  1166. Default: 1e-5
  1167. eps_mode: mode of calculation for eps, "max" or "additive".
  1168. Default: "additive"
  1169. group: communication group, caculate mean and variance between this group.
  1170. Default: :obj:`~megengine.distributed.WORLD`
  1171. """
  1172. _eps_mode = eps_mode.lower()
  1173. assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
  1174. if _eps_mode == "additive" and not (is_distributed() and training):
  1175. return batch_norm(
  1176. inp,
  1177. running_mean,
  1178. running_var,
  1179. weight,
  1180. bias,
  1181. training=training,
  1182. momentum=momentum,
  1183. eps=eps,
  1184. )
  1185. if amp._enabled:
  1186. inp, weight, bias, running_mean, running_var = cast_tensors(
  1187. inp, weight, bias, running_mean, running_var, promote=True
  1188. )
  1189. _channels = make_shape_tuple(inp.shape)[1]
  1190. _ndim = inp.ndim
  1191. _device = inp.device
  1192. _dtype = inp.dtype
  1193. if _ndim != 4:
  1194. raise NotImplementedError("sync_batch_norm for ndim != 4")
  1195. def _make_full_if_none(x, value):
  1196. if x is None:
  1197. x = Const(value, inp.dtype, _device)
  1198. (result,) = apply(builtin.Broadcast(), x, reduce_shape)
  1199. return result
  1200. elif x.ndim == 1:
  1201. (result,) = apply(builtin.Reshape(), x, reduce_shape)
  1202. return result
  1203. return x
  1204. (
  1205. syncbn_stage0,
  1206. syncbn_stage1,
  1207. syncbn_stage1_inference,
  1208. syncbn_stage2,
  1209. syncbn_concat_stats,
  1210. syncbn_split_stats,
  1211. ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
  1212. reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)
  1213. eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
  1214. weight = _make_full_if_none(weight, 1)
  1215. bias = _make_full_if_none(bias, 0)
  1216. if training:
  1217. if is_distributed():
  1218. # reduce all nodes' data to calculate mean and variance
  1219. (stat,) = apply(
  1220. syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
  1221. )
  1222. stat = all_reduce_sum(stat, group)
  1223. reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
  1224. outvar, channel_mean, *_ = apply(
  1225. syncbn_stage1(),
  1226. inp,
  1227. reduce_size,
  1228. channel_x1s,
  1229. channel_x2s,
  1230. eps,
  1231. weight,
  1232. bias,
  1233. )
  1234. else:
  1235. assert running_var is not None and running_mean is not None
  1236. channel_mean = running_mean
  1237. channel_var = running_var
  1238. outvar, *_ = apply(
  1239. syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
  1240. )
  1241. # outvar = output * weight + bias
  1242. # where output = inp * invsqrt_channel_variance + (
  1243. # -channel_mean * invsqrt_channel_variance
  1244. # )
  1245. # Manually expand output for gopt
  1246. if training and running_var is not None and running_mean is not None:
  1247. momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
  1248. running_mean[...], running_var[...] = apply(
  1249. syncbn_stage2(),
  1250. running_mean,
  1251. running_var,
  1252. momentum,
  1253. reduce_size,
  1254. channel_x1s,
  1255. channel_x2s,
  1256. channel_mean,
  1257. )
  1258. if amp._enabled:
  1259. outvar = outvar.astype("float16")
  1260. return outvar
  1261. def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
  1262. r"""Returns a new tensor where each of the elements are randomly set to zero
  1263. with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
  1264. Args:
  1265. inp: input tensor.
  1266. drop_prob: probability to drop (set to zero) a single element.
  1267. training: the default behavior of ``dropout`` during training is to rescale the output,
  1268. then it can be replaced by an :class:`~.module.identify.Identity` during inference. Default: True
  1269. Returns:
  1270. the ouput tensor
  1271. Examples:
  1272. >>> import numpy as np
  1273. >>> data = Tensor(np.ones(10000000, dtype=np.float32))
  1274. >>> out = F.nn.dropout(data, 1.0 / 3.0, training=True)
  1275. >>> assert not out.numpy().all()
  1276. >>> out = F.nn.dropout(data, 1.0 / 3.0, training=False)
  1277. >>> assert out.numpy().all()
  1278. >>> out.numpy()
  1279. array([1., 1., 1., ..., 1., 1., 1.], dtype=float32)
  1280. """
  1281. assert 0 <= drop_prob < 1
  1282. if not training or drop_prob == 0:
  1283. return inp
  1284. # model in training mode, e.g. model.train()
  1285. op = Dropout(drop_prob=drop_prob, seed=_get_global_rng_seed(), handle=0)
  1286. outputs = apply(op, inp)
  1287. return outputs[0]
  1288. def one_hot(inp: Tensor, num_classes: int) -> Tensor:
  1289. r"""Performs one-hot encoding for the input tensor.
  1290. Args:
  1291. inp: input tensor.
  1292. num_classes: number of classes denotes the last dimension of the output tensor.
  1293. Examples:
  1294. >>> import numpy as np
  1295. >>> x = Tensor(np.arange(1, 4, dtype=np.int32))
  1296. >>> F.one_hot(x, num_classes=4)
  1297. Tensor([[0 1 0 0]
  1298. [0 0 1 0]
  1299. [0 0 0 1]], dtype=int32, device=xpux:0)
  1300. """
  1301. zeros_tensor = zeros(
  1302. list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
  1303. )
  1304. ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
  1305. op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
  1306. (result,) = apply(op, zeros_tensor, inp, ones_tensor)
  1307. return result
  1308. def embedding(
  1309. inp: Tensor,
  1310. weight: Tensor,
  1311. padding_idx: Optional[int] = None,
  1312. max_norm: Optional[float] = None,
  1313. norm_type: Optional[float] = None,
  1314. ):
  1315. r"""Applies lookup table for embedding.
  1316. Args:
  1317. inp: tensor with indices.
  1318. weight: learnable weights which embeds from.
  1319. padding_idx: should be set to None, not supported now.
  1320. max_norm: should be set to None, not supported now.
  1321. norm_type: should be set to None, not supported now.
  1322. Refer to :class:`~.module.Embedding` for more information.
  1323. """
  1324. if padding_idx is not None:
  1325. raise ValueError("Not support padding_idx Now!")
  1326. if max_norm is not None or norm_type is not None:
  1327. raise ValueError("Not support weight normlization Now!")
  1328. dest_shp = list(inp.shape) + [weight.shape[-1]]
  1329. return weight[inp.reshape(-1)].reshape(dest_shp)
  1330. def indexing_one_hot(
  1331. src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  1332. ) -> Tensor:
  1333. r"""One-hot indexing for some axes.
  1334. Args:
  1335. src: input tensor.
  1336. index: index tensor.
  1337. axis: axis on src for which values in index index. Default: 1
  1338. keepdims: whether not to remove the axis in result. Default: False
  1339. Examples:
  1340. >>> src = Tensor([[1.0, 2.0]])
  1341. >>> index = Tensor([0])
  1342. >>> val = F.indexing_one_hot(src, index)
  1343. >>> val.numpy()
  1344. array([1.], dtype=float32)
  1345. """
  1346. assert isinstance(src, Tensor), "src must be of Tensor type"
  1347. op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
  1348. index = convert_single_value(index, dtype="int32", device=src.device)
  1349. (result,) = apply(op, src, index)
  1350. if not keepdims:
  1351. result = squeeze(result, axis)
  1352. return result
  1353. def sliding_window(
  1354. inp: Tensor,
  1355. kernel_size: Union[int, Tuple[int, int]],
  1356. padding: Union[int, Tuple[int, int]] = 0,
  1357. stride: Union[int, Tuple[int, int]] = 1,
  1358. dilation: Union[int, Tuple[int, int]] = 1,
  1359. ) -> Tensor:
  1360. r"""Extracts sliding local blocks from a batched input tensor.
  1361. Refer to :class:`~.module.sliding_window.SlidingWindow` for more information.
  1362. Args:
  1363. inp: input tensor.
  1364. kernel_size: size of the window.
  1365. padding: implicit zero padding added on both sides of input. Default: 0
  1366. stride: stride of the window. Default: 1
  1367. dilation: dilation of the window. Default: 1
  1368. """
  1369. padding_h, padding_w = expand_hw(padding)
  1370. stride_h, stride_w = expand_hw(stride)
  1371. dilation_h, dilation_w = expand_hw(dilation)
  1372. window_h, window_w = expand_hw(kernel_size)
  1373. op = builtin.Images2Neibs(
  1374. pad_h=padding_h,
  1375. pad_w=padding_w,
  1376. stride_h=stride_h,
  1377. stride_w=stride_w,
  1378. dilate_h=dilation_h,
  1379. dilate_w=dilation_w,
  1380. window_h=window_h,
  1381. window_w=window_w,
  1382. )
  1383. (output,) = apply(op, inp)
  1384. return output
  1385. def sliding_window_transpose(
  1386. inp: Tensor,
  1387. output_size: Union[int, Tuple[int, int]],
  1388. kernel_size: Union[int, Tuple[int, int]],
  1389. padding: Union[int, Tuple[int, int]] = 0,
  1390. stride: Union[int, Tuple[int, int]] = 1,
  1391. dilation: Union[int, Tuple[int, int]] = 1,
  1392. ) -> Tensor:
  1393. r"""Sum over the sliding windows on the corresponding input location.
  1394. Refer to :class:`~.module.sliding_window.SlidingWindowTranspose` for more information.
  1395. Args:
  1396. inp: input tensor.
  1397. output_size: shape of output tensor.
  1398. kernel_size: size of the window.
  1399. padding: implicit zero padding added on both sides of input. Default: 0
  1400. stride: stride of the window. Default: 1
  1401. dilation: dilation of the window. Default: 1
  1402. """
  1403. output_h, output_w = expand_hw(output_size)
  1404. padding_h, padding_w = expand_hw(padding)
  1405. stride_h, stride_w = expand_hw(stride)
  1406. dilation_h, dilation_w = expand_hw(dilation)
  1407. window_h, window_w = expand_hw(kernel_size)
  1408. expected_h = (
  1409. output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
  1410. ) // stride_h + 1
  1411. expected_w = (
  1412. output_w + 2 * padding_w - dilation_w * (window_w - 1) - 1
  1413. ) // stride_w + 1
  1414. assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6"
  1415. assert (
  1416. inp.shape[2] == expected_h and inp.shape[3] == expected_w
  1417. ), "the input shape and output size do not match"
  1418. op = builtin.SlidingWindowTranspose(
  1419. out_h=output_h,
  1420. out_w=output_w,
  1421. pad_h=padding_h,
  1422. pad_w=padding_w,
  1423. stride_h=stride_h,
  1424. stride_w=stride_w,
  1425. dilate_h=dilation_h,
  1426. dilate_w=dilation_w,
  1427. window_h=window_h,
  1428. window_w=window_w,
  1429. )
  1430. (output,) = apply(op, inp)
  1431. return output
  1432. def pad(
  1433. src: Tensor,
  1434. pad_width: Tuple[Tuple[int, int], ...],
  1435. mode: str = "constant",
  1436. constant_value: float = 0.0,
  1437. ) -> Tensor:
  1438. r"""Pads the input tensor.
  1439. Args:
  1440. pad_width: A tuple. Each element in the tuple is the tuple of 2-elements,
  1441. the 2 elements represent the padding size on both sides of the current dimension, ``(front_offset, back_offset)``
  1442. mode: One of the following string values. Default: ``'constant'``
  1443. * ``'constant'``: Pads with a constant value.
  1444. * ``'reflect'``: Pads with the reflection of the tensor mirrored on the first and last values of the tensor along each axis.
  1445. * ``'replicate'``: Pads with the edge values of tensor.
  1446. constant_val: Fill value for ``'constant'`` padding. Default: 0
  1447. Examples:
  1448. >>> import numpy as np
  1449. >>> inp = Tensor([[1., 2., 3.],[4., 5., 6.]])
  1450. >>> inp
  1451. Tensor([[1. 2. 3.]
  1452. [4. 5. 6.]], device=xpux:0)
  1453. >>> F.nn.pad(inp, pad_width=((1, 1),), mode="constant")
  1454. Tensor([[0. 0. 0.]
  1455. [1. 2. 3.]
  1456. [4. 5. 6.]
  1457. [0. 0. 0.]], device=xpux:0)
  1458. >>> F.nn.pad(inp, pad_width=((1, 1),), mode="constant", constant_value=9)
  1459. Tensor([[9. 9. 9.]
  1460. [1. 2. 3.]
  1461. [4. 5. 6.]
  1462. [9. 9. 9.]], device=xpux:0)
  1463. >>> F.nn.pad(inp, pad_width=((1, 1), (1, 2)), mode="reflect")
  1464. Tensor([[5. 4. 5. 6. 5. 4.]
  1465. [2. 1. 2. 3. 2. 1.]
  1466. [5. 4. 5. 6. 5. 4.]
  1467. [2. 1. 2. 3. 2. 1.]], device=xpux:0)
  1468. >>> F.nn.pad(inp, pad_width=((1, 1), (1, 2)), mode="replicate")
  1469. Tensor([[1. 1. 2. 3. 3. 3.]
  1470. [1. 1. 2. 3. 3. 3.]
  1471. [4. 4. 5. 6. 6. 6.]
  1472. [4. 4. 5. 6. 6. 6.]], device=xpux:0)
  1473. """
  1474. p_offsets = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  1475. assert mode.lower() in ["constant", "edge", "replicate", "reflect"]
  1476. if mode.lower() == "edge":
  1477. mode = "replicate"
  1478. for i in range(0, len(pad_width)):
  1479. p_offsets[i * 2] = pad_width[i][0]
  1480. p_offsets[i * 2 + 1] = pad_width[i][1]
  1481. op = builtin.Padding(
  1482. front_offset_dim0=p_offsets[0],
  1483. front_offset_dim1=p_offsets[2],
  1484. front_offset_dim2=p_offsets[4],
  1485. front_offset_dim3=p_offsets[6],
  1486. front_offset_dim4=p_offsets[8],
  1487. front_offset_dim5=p_offsets[10],
  1488. front_offset_dim6=p_offsets[12],
  1489. back_offset_dim0=p_offsets[1],
  1490. back_offset_dim1=p_offsets[3],
  1491. back_offset_dim2=p_offsets[5],
  1492. back_offset_dim3=p_offsets[7],
  1493. back_offset_dim4=p_offsets[9],
  1494. back_offset_dim5=p_offsets[11],
  1495. back_offset_dim6=p_offsets[13],
  1496. padding_val=constant_value,
  1497. padding_mode=mode.upper(),
  1498. )
  1499. (output,) = apply(op, src)
  1500. return output
  1501. def local_response_norm(
  1502. inp: Tensor,
  1503. kernel_size: int = 5,
  1504. k: float = 2.0,
  1505. alpha: float = 1e-4,
  1506. beta: float = 0.75,
  1507. ) -> Tensor:
  1508. r"""
  1509. Apply local response normalization to the input tensor.
  1510. Args:
  1511. kernel_size: the size of the kernel to apply LRN on.
  1512. k: hyperparameter k. The default vaule is 2.0.
  1513. alpha: hyperparameter alpha. The default value is 1e-4.
  1514. beta: hyperparameter beta. The default value is 0.75.
  1515. Example:
  1516. >>> import numpy as np
  1517. >>> inp = Tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
  1518. >>> GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
  1519. ... [ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
  1520. ... [ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
  1521. ... [14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
  1522. ... [19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
  1523. >>> out = F.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
  1524. >>> np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
  1525. """
  1526. op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
  1527. (output,) = apply(op, inp)
  1528. return output
  1529. @lru_cache(maxsize=None)
  1530. def _get_layerPixelShuffle(device, dtype, dim_order):
  1531. @subgraph("LayerPixelShuffle", dtype, device, 3)
  1532. def layerPixelShuffle(inputs, f, c):
  1533. inp, shape_0, shape_1 = inputs
  1534. inp = f(Reshape(), inp, shape_0)
  1535. inp = f(Dimshuffle(dim_order), inp)
  1536. oup = f(Reshape(), inp, shape_1)
  1537. return (oup,), (True,)
  1538. return layerPixelShuffle
  1539. def layerPixelShuffle_traceable(inp, upscale_factor):
  1540. assert upscale_factor > 0, "upscale_factor should larger than 0"
  1541. assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
  1542. assert (
  1543. inp.shape[-3] % (upscale_factor ** 2) == 0
  1544. ), "the -3 dimension should be divided by (upscale_factor ** 2)"
  1545. _device = inp.device
  1546. _dtype = inp.dtype
  1547. shape_ori = inp.shape
  1548. high_dim = shape_ori[:-3]
  1549. square = upscale_factor ** 2
  1550. n = 1
  1551. for item in high_dim:
  1552. n *= item
  1553. shape_0 = (
  1554. n,
  1555. int(shape_ori[-3] / square),
  1556. upscale_factor,
  1557. upscale_factor,
  1558. shape_ori[-2],
  1559. shape_ori[-1],
  1560. )
  1561. shape_1 = (
  1562. *high_dim,
  1563. int(shape_ori[-3] / square),
  1564. shape_ori[-2] * upscale_factor,
  1565. shape_ori[-1] * upscale_factor,
  1566. )
  1567. dim_order = (0, 1, 4, 2, 5, 3)
  1568. layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
  1569. shape_0 = convert_single_value(shape_0, device=inp.device)
  1570. shape_1 = convert_single_value(shape_1, device=inp.device)
  1571. outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
  1572. return outvar
  1573. def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
  1574. """
  1575. Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
  1576. shape `(..., C, H * r, W * r)`, where `r` is an upscale factor, where `...` is
  1577. zero or more batch dimensions.
  1578. :param inp: input tensor.
  1579. :param upscale_factor: upscale factor of pixel_shuffle.
  1580. :return: output tensor.
  1581. """
  1582. return pixel_shuffle_cpp(inp, upscale_factor, layerPixelShuffle_traceable)
  1583. def region_restricted_conv(
  1584. inp: Tensor,
  1585. weight: Tensor,
  1586. rin: Tensor,
  1587. rout: Tensor,
  1588. bias: Optional[Tensor] = None,
  1589. stride: Union[int, Tuple[int, int, int]] = 1,
  1590. padding: Union[int, Tuple[int, int, int]] = 0,
  1591. dilation: Union[int, Tuple[int, int, int]] = 1,
  1592. groups: int = 1,
  1593. conv_mode: str = "cross_correlation",
  1594. compute_mode="default",
  1595. ) -> Tensor:
  1596. r"""Region Restricted convolution operation.
  1597. Refer to :class:`~.RegionRestrictedConv` for more information.
  1598. Args:
  1599. inp: feature map of the convolution operation.
  1600. weight: convolution kernel.
  1601. rin: input mask
  1602. rout: output mask
  1603. bias: bias added to the result of convolution (if given).
  1604. stride: stride of the 2D region restricted convolution operation. Default: 1
  1605. padding: size of the paddings added to the input on both sides of its
  1606. spatial dimensions. Only zero-padding is supported. Default: 0
  1607. dilation: dilation of the 2D convolution operation. Default: 1
  1608. groups: number of groups into which the input and output channels are divided,
  1609. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  1610. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  1611. and the shape of weight should be ``(groups, out_channel // groups,
  1612. in_channels // groups, depth, height, width)``. Default: 1
  1613. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  1614. Returns:
  1615. output tensor.
  1616. """
  1617. assert conv_mode.lower() == "cross_correlation"
  1618. pad_h, pad_w = expand_hw(padding)
  1619. stride_h, stride_w = expand_hw(stride)
  1620. dilate_h, dilate_w = expand_hw(dilation)
  1621. sparse_type = "group"
  1622. assert groups > 0, (
  1623. "RegionRestrictedConv expected grouped conv mode, \
  1624. which requires groups > 0, but got groups=%d"
  1625. % (groups)
  1626. )
  1627. op = builtin.RegionRestrictedConvolution(
  1628. stride_h=stride_h,
  1629. stride_w=stride_w,
  1630. pad_h=pad_h,
  1631. pad_w=pad_w,
  1632. dilate_h=dilate_h,
  1633. dilate_w=dilate_w,
  1634. mode=conv_mode,
  1635. compute_mode=compute_mode,
  1636. sparse=sparse_type,
  1637. )
  1638. (output,) = apply(op, inp, weight, rin, rout)
  1639. if bias is not None:
  1640. output += bias
  1641. return output
  1642. from .quantized import conv_bias_activation # isort:skip
  1643. from .loss import * # isort:skip
  1644. from .metric import * # isort:skip
  1645. from .vision import * # isort:skip