You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vm_me.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """VM implementations based on numpy."""
  16. import numpy as np
  17. from mindspore._checkparam import Validator as validator
  18. def avg_pooling(x, pool_h, pool_w, stride):
  19. """
  20. Applies average pooling over an input array.
  21. Args:
  22. x (numpy.ndarray): The input array to be average pooled.
  23. pool_h (int): Height of the pooling window.
  24. pool_w (int): Width of the pooling window.
  25. stride (int): The stride of the sliding window.
  26. Returns:
  27. numpy.ndarray, an output array after applying average pooling on input array.
  28. """
  29. validator.check_positive_int(stride, "stride")
  30. num, channel, height, width = x.shape
  31. out_h = (height - pool_h) // stride + 1
  32. out_w = (width - pool_w) // stride + 1
  33. col = im2col(x, pool_h, pool_w, stride)
  34. col = col.reshape(-1, pool_h * pool_w)
  35. out = np.mean(col, axis=1)
  36. out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
  37. return out
  38. def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride):
  39. """
  40. Gets grad of average pooling.
  41. Args:
  42. x (numpy.ndarray): The input array to be average pooled.
  43. dout (numpy.ndarray): The grad of pre-layer.
  44. pool_h (int): Height of the pooling window.
  45. pool_w (int): Width of the pooling window.
  46. stride (int): The stride of the sliding window.
  47. Returns:
  48. numpy.ndarray, grad of avgerage pooling.
  49. """
  50. # pylint: disable=unused-argument
  51. _, _, height, width = dout.shape
  52. dx = np.zeros(origin_shape)
  53. for i in range(height):
  54. for j in range(width):
  55. dx[:, :, i:(i + pool_h), j:(j + pool_w)] += np.ones((pool_h, pool_w))
  56. return dx
  57. def _batch_norm(x, scale, shift, running_mean=None, running_var=None,
  58. eps=1e-05, momentum=0.1, is_training=True):
  59. """Batch normalization over an array."""
  60. _, c_h_w = x.shape
  61. # Handle running_mean and running_var are not None
  62. # if running_mean is None:
  63. # running_mean = np.zeros(c_h_w)
  64. # running_var = np.zeros(c_h_w)
  65. running_mean = np.zeros(c_h_w)
  66. running_var = np.zeros(c_h_w)
  67. if np.ndim(scale) > 0:
  68. scale = scale.mean()
  69. if np.ndim(shift) > 0:
  70. shift = shift.mean()
  71. if is_training:
  72. x_mean = np.mean(x, axis=0)
  73. x_var = np.var(x, axis=0)
  74. # Normalization followed by Affine transformation
  75. x_norm = (x - x_mean) / np.sqrt(x_var + eps)
  76. # Estimate running average of mean and variance to use at test time
  77. running_mean = momentum * running_mean + (1 - momentum) * x_mean
  78. running_var = momentum * running_var + (1 - momentum) * x_var
  79. else:
  80. # normalize using running average
  81. x_norm = (x - running_mean) / np.sqrt(running_var + eps)
  82. x_mean = running_mean
  83. x_var = running_var
  84. out = scale * x_norm + shift
  85. return out, x_mean, x_var, running_mean, running_var
  86. def batch_norm(x, scale=1, shift=0, mean=None, variance=None,
  87. eps=1e-05, momentum=0.1, is_training=True):
  88. """Batch normalization over an array."""
  89. input_shape = x.shape
  90. if x.ndim != 2:
  91. batch_num = x.shape[0]
  92. x = x.reshape(batch_num, -1)
  93. out, _, _, running_mean, running_var = _batch_norm(x, scale, shift, mean, variance, \
  94. eps, momentum, is_training)
  95. return out.reshape(*input_shape), np.array(scale), np.array(shift), running_mean, running_var
  96. def _batch_norm_grad(dout, x, scale, save_mean, save_inv_variance, \
  97. eps=1e-05, momentum=0.1, is_training=True):
  98. """Batch normalization over an array."""
  99. if x.ndim != 2:
  100. batch_num = x.shape[0]
  101. x = x.reshape(batch_num, -1)
  102. if np.ndim(scale) > 0:
  103. scale = scale.mean()
  104. x_norm, x_mean, x_var, _, _ = _batch_norm(x, scale, shift=0, running_mean=save_mean, \
  105. running_var=save_inv_variance, \
  106. eps=eps, momentum=momentum, is_training=is_training)
  107. batch_size = x.shape[0]
  108. dx_norm = scale * dout
  109. dvar = np.sum(dx_norm * (x - x_mean) * ((x_var + eps) ** (-3.0 / 2)) * (-1.0 / 2), axis=0)
  110. dmean = np.sum(dx_norm * (-1.0 / np.sqrt(x_var + eps)), axis=0) \
  111. + dvar * (np.sum(-2 * (x - x_mean), axis=0) * (1.0 / batch_size))
  112. dx = dx_norm * (1.0 / np.sqrt(x_var + eps)) + dvar * (2.0 * (x - x_mean) / batch_size) + dmean * (1.0 / batch_size)
  113. dgamma = np.sum(dout * x_norm, axis=0)
  114. dbeta = np.sum(dout, axis=0)
  115. return dx, dgamma, dbeta
  116. def batch_norm_grad(dy, x, scale, save_mean, save_inv_variance):
  117. """Batch normalization over an array."""
  118. if dy.ndim != 2:
  119. batch_size = dy.shape[0]
  120. dy = dy.reshape(batch_size, -1)
  121. dx, dgamma, dbeta = _batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
  122. input_shape = x.shape
  123. dx = dx.reshape(*input_shape)
  124. return dx, dgamma, dbeta
  125. def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
  126. """Rearranges a row vector to an image."""
  127. if isinstance(stride, int):
  128. stride_h = stride
  129. stride_w = stride
  130. elif isinstance(stride, tuple) and len(stride) == 2:
  131. stride_h = stride[0]
  132. stride_w = stride[1]
  133. elif isinstance(stride, tuple) and len(stride) == 4:
  134. stride_h = stride[2]
  135. stride_w = stride[3]
  136. else:
  137. raise ValueError(f"The \'stride\' should be an int number or "
  138. f"a tuple of two or four int numbers, but got {stride}")
  139. if isinstance(pad, int):
  140. pad_top = pad
  141. pad_bottom = pad
  142. pad_left = pad
  143. pad_right = pad
  144. elif isinstance(pad, tuple) and len(pad) == 2:
  145. pad_top = pad[0]
  146. pad_bottom = pad[0]
  147. pad_left = pad[1]
  148. pad_right = pad[1]
  149. elif isinstance(pad, tuple) and len(pad) == 4:
  150. pad_top, pad_bottom, pad_left, pad_right = pad
  151. else:
  152. raise ValueError(f"The \'pad\' should be an int number or "
  153. f"a tuple of two or four int numbers, but got {pad}")
  154. batch_num, channel, height, width = input_shape
  155. out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
  156. out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
  157. col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
  158. .transpose(0, 3, 4, 5, 1, 2)
  159. img = np.zeros((batch_num,
  160. channel,
  161. height + pad_top + pad_bottom + stride_h - 1,
  162. width + pad_left + pad_right + stride_w - 1)) \
  163. .astype(col.dtype)
  164. for y in range(filter_h):
  165. y_max = y + stride_h * out_h
  166. for x in range(filter_w):
  167. x_max = x + stride_h * out_w
  168. img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
  169. return img[:, :, pad_top:height + pad_bottom, pad_left:width + pad_right]
  170. def convolve(x, w, b=None, pad_mode="valid"):
  171. """
  172. Gets the discrete, linear convolution of two one-dimensional sequences.
  173. Args:
  174. x (numpy.ndarray): One-dimensional input array.
  175. w (numpy.ndarray): One-dimensional input array.
  176. b (numpy.ndarray): One-dimensional input array. Default: None.
  177. pad_mode (str): Padding mode which can be: "full" means returns the
  178. convolution at each point of overlap, with an output shape
  179. of (N+M-1,); "same" means returns output of length max(M, N);
  180. Amd "valid" means returns output of length max(M, N) - min(M, N)
  181. + 1. Default: "valid".
  182. Returns:
  183. numpy.ndarray, discrete, linear convolution of x and w, then plus b.
  184. """
  185. if pad_mode not in {"same", "valid"}:
  186. pad_mode = "full"
  187. y = np.convolve(x, w, pad_mode)
  188. if b:
  189. y += b
  190. return y
  191. def conv2d(x, weight, bias=None, stride=1, pad=0,
  192. dilation=1, groups=1, padding_mode='zeros'):
  193. """Convolution 2D."""
  194. # pylint: disable=unused-argument
  195. validator.check_value_type('stride', stride, (int, tuple))
  196. if isinstance(stride, int):
  197. stride = (stride, stride)
  198. elif len(stride) == 4:
  199. stride = (stride[2], stride[3])
  200. if len(stride) != 2 or (not isinstance(stride[0], int)) or \
  201. (not isinstance(stride[1], int)) or \
  202. stride[0] < 1 or stride[1] < 1:
  203. raise ValueError(f"The \'stride\' of \'conv2d\' should be an positive int number or "
  204. f"a tuple of two positive int numbers, but got {stride}")
  205. stride_h = stride[0]
  206. stride_w = stride[1]
  207. validator.check_value_type('dilation', dilation, (int, tuple))
  208. if isinstance(dilation, int):
  209. dilation = (dilation, dilation)
  210. elif len(dilation) == 4:
  211. dilation = (dilation[2], dilation[3])
  212. if len(dilation) != 2 or (not isinstance(dilation[0], int)) or \
  213. (not isinstance(dilation[1], int)) or \
  214. dilation[0] < 1 or dilation[1] < 1:
  215. raise ValueError(f"The \'dilation\' of \'conv2d\' should be an positive int number or "
  216. f"a tuple of two positive int numbers, but got {dilation}")
  217. dilation_h = dilation[0]
  218. dilation_w = dilation[1]
  219. if isinstance(pad, int):
  220. pad_top = pad
  221. pad_bottom = pad
  222. pad_left = pad
  223. pad_right = pad
  224. elif isinstance(pad, tuple) and len(pad) == 4:
  225. pad_top, pad_bottom, pad_left, pad_right = pad
  226. else:
  227. raise ValueError(f"The \'pad\' should be an int number or "
  228. f"a tuple of two or four int numbers, but got {pad}")
  229. batch_num, _, x_h, x_w = x.shape
  230. filter_num, _, filter_h, filter_w = weight.shape
  231. out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
  232. out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
  233. col = im2col(x, filter_h, filter_w, stride, pad, dilation)
  234. col_w = np.reshape(weight, (filter_num, -1)).T
  235. out = np.dot(col, col_w)
  236. out = out.reshape(batch_num, out_h, out_w, -1).transpose(0, 3, 1, 2)
  237. if bias is not None:
  238. out += bias
  239. return out
  240. def conv2d_backprop_filter(dout, x, w_size, stride=1, pad=0):
  241. """Backpropagation filter for conv2d."""
  242. filter_num, channel, filter_height, filter_width = w_size
  243. dout = dout.transpose(0, 2, 3, 1).reshape(-1, filter_num)
  244. col = im2col(x, filter_height, filter_width, stride, pad)
  245. dw = np.dot(col.T, dout)
  246. dw = dw.transpose(1, 0).reshape(filter_num, channel, filter_height, filter_width)
  247. return dw
  248. def conv2d_backprop_input(dout, x_size, weight, stride=1, pad=0):
  249. """Backpropagation input for conv2d."""
  250. filter_num, _, filter_h, filter_w = weight.shape
  251. dout = dout.transpose(0, 2, 3, 1).reshape(-1, filter_num)
  252. col_w = weight.reshape(filter_num, -1).T
  253. dcol = np.dot(dout, col_w.T)
  254. dx = col2im(dcol, x_size, filter_h, filter_w, stride, pad)
  255. return dx
  256. def flatten(x):
  257. """
  258. Flattens an array to one dimension.
  259. Args:
  260. x (numpy.ndarray): An array to be flattened.
  261. Returns:
  262. numpy.ndarray, a flattened array in one dimension.
  263. """
  264. return x.flatten()
  265. def flatten2(x):
  266. """
  267. Flattens an array to one dimension by reshape.
  268. Args:
  269. x (numpy.ndarray): An array to be flattened.
  270. Returns:
  271. numpy.ndarray, a flattened array in one dimension.
  272. """
  273. return x.reshape(1, -1)
  274. def flatten_batch(x):
  275. """
  276. Flattens a batch of arrays to one dimension.
  277. Args:
  278. x (numpy.ndarray): A batch of arrays to be flattened.
  279. Returns:
  280. numpy.ndarray, a flattened one dimension array.
  281. """
  282. return x.reshape(x.shape[0], -1)
  283. def flatten_grad(dout, x):
  284. """Grad of flatten."""
  285. dout = np.reshape(dout, x)
  286. return dout
  287. def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
  288. """Rearranges an image to row vector."""
  289. if isinstance(stride, int):
  290. stride_h = stride
  291. stride_w = stride
  292. elif isinstance(stride, tuple) and len(stride) == 2:
  293. stride_h = stride[0]
  294. stride_w = stride[1]
  295. elif isinstance(stride, tuple) and len(stride) == 4:
  296. stride_h = stride[2]
  297. stride_w = stride[3]
  298. else:
  299. raise ValueError(f"The \'stride\' should be an int number or "
  300. f"a tuple of two or four int numbers, but got {stride}")
  301. if isinstance(dilation, int):
  302. dilation_h = dilation
  303. dilation_w = dilation
  304. elif isinstance(dilation, tuple) and len(dilation) == 2:
  305. dilation_h = dilation[0]
  306. dilation_w = dilation[1]
  307. elif isinstance(dilation, tuple) and len(dilation) == 4:
  308. dilation_h = dilation[2]
  309. dilation_w = dilation[3]
  310. else:
  311. raise ValueError(f"The \'dilation\' should be an int number or "
  312. f"a tuple of two or four int numbers, but got {dilation}")
  313. if isinstance(pad, int):
  314. pad_top = pad
  315. pad_bottom = pad
  316. pad_left = pad
  317. pad_right = pad
  318. elif isinstance(pad, tuple) and len(pad) == 4:
  319. pad_top, pad_bottom, pad_left, pad_right = pad
  320. else:
  321. raise ValueError(f"The \'pad\' should be an int number or "
  322. f"a tuple of two or four int numbers, but got {pad}")
  323. batch_num, channel, height, width = img.shape
  324. out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
  325. out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
  326. img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
  327. col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
  328. for y in range(filter_h):
  329. y_max = y + stride_h * out_h
  330. for x in range(filter_w):
  331. x_max = x + stride_h * out_w
  332. col[:, :, y, x, :, :] = img[:, :, y:y_max:stride_h, x:x_max:stride_h]
  333. col = col.transpose(0, 4, 5, 1, 2, 3).reshape(batch_num * out_h * out_w, -1)
  334. return col
  335. def matmul(x, w, b=None):
  336. """
  337. Dot product of array x and w, then plus array b if b is not None.
  338. Args:
  339. x (numpy.ndarray): Represents the input array.
  340. w (numpy.ndarray): Represents weights array.
  341. b (numpy.ndarray): Represents bias array which has the same shape as x. Default: None.
  342. Returns:
  343. numpy.ndarray, the result of (x*w + b).
  344. """
  345. y = np.dot(x, w)
  346. if b:
  347. y += b
  348. return y
  349. def max_pooling(x, pool_h, pool_w, stride):
  350. """Max pooling."""
  351. validator.check_positive_int(stride, "stride")
  352. num, channel, height, width = x.shape
  353. out_h = (height - pool_h) // stride + 1
  354. out_w = (width - pool_w) // stride + 1
  355. col = im2col(x, pool_h, pool_w, stride)
  356. col = col.reshape(-1, pool_h * pool_w)
  357. out = np.max(col, axis=1)
  358. out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
  359. return out
  360. def max_pool_grad(x, dout, pool_h, pool_w, stride):
  361. """Grad of max pooling."""
  362. dout = dout.transpose(0, 2, 3, 1)
  363. pool_size = pool_h * pool_w
  364. dmax = np.zeros((dout.size, pool_size), dout.dtype)
  365. col = im2col(x, pool_h, pool_w, stride)
  366. col = col.reshape(-1, pool_h * pool_w)
  367. arg_max = np.argmax(col, axis=1)
  368. dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
  369. dmax = dmax.reshape(dout.shape + (pool_size,))
  370. dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
  371. dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
  372. return dx
  373. def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
  374. """Grad of max pooling with argmax."""
  375. dout = dout.transpose(0, 2, 3, 1)
  376. pool_size = pool_h * pool_w
  377. dmax = np.zeros((dout.size, pool_size), dout.dtype)
  378. dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
  379. dmax = dmax.reshape(dout.shape + (pool_size,))
  380. dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
  381. dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
  382. return dx
  383. def max_pool_with_argmax(x, pool_h, pool_w, stride):
  384. """Max pooling with argmax."""
  385. validator.check_positive_int(stride, "stride")
  386. num, channel, height, width = x.shape
  387. out_h = (height - pool_h) // stride + 1
  388. out_w = (width - pool_w) // stride + 1
  389. col = im2col(x, pool_h, pool_w, stride)
  390. col = col.reshape(-1, pool_h * pool_w)
  391. out = np.max(col, axis=1)
  392. out_argmax = np.argmax(col, axis=1)
  393. out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
  394. out_argmax = out_argmax.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
  395. return out, out_argmax
  396. def relu(x):
  397. """
  398. Rectified linear unit.
  399. Args:
  400. x (numpy.ndarray): The input array.
  401. Returns:
  402. numpy.ndarray, the array applied relu.
  403. """
  404. return x * (x > 0)
  405. def relu_grad(y):
  406. """
  407. Grad of relu.
  408. Args:
  409. y (numpy.ndarray): The input array.
  410. Returns:
  411. numpy.ndarray, the array applied grad of relu.
  412. """
  413. y[y <= 0] = 0
  414. y[y > 0] = 1
  415. return y
  416. def sigmoid(x):
  417. """
  418. Sigmoid activation function.
  419. Args:
  420. x (numpy.ndarray): The input array.
  421. Returns:
  422. numpy.ndarray, the array applied sigmoid.
  423. """
  424. return 1 / (1 + np.exp(x * -1))
  425. def tanh(x):
  426. """
  427. Computes hyperbolic tangent element-wise.
  428. Args:
  429. x (numpy.ndarray): The input array.
  430. Returns:
  431. numpy.ndarray, the array applied tanh.
  432. """
  433. a = np.exp(x) - np.exp(x * -1)
  434. b = np.exp(x) + np.exp(x * -1)
  435. return a / b
  436. def softmax(x, axis=None):
  437. """
  438. Softmax function which is `softmax(x) = np.exp(x)/sum(np.exp(x))`.
  439. Args:
  440. x (numpy.ndarray): Input array.
  441. axis (Union[int, tuple[int]]): Axis to compute values along. Default: None.
  442. Returns:
  443. numpy.ndarray, has the same shape as x.
  444. """
  445. from scipy.special import softmax as scipy_softmax
  446. return scipy_softmax(x, axis)
  447. def softmax_cross_entropy_with_logits(logits, labels):
  448. sample_num = labels.shape[0]
  449. prob = softmax(logits)
  450. log_likelihood = -np.log(prob[range(sample_num)]) * labels
  451. loss = np.sum(log_likelihood)
  452. dx = prob.copy()
  453. dx[range(sample_num)] -= labels
  454. return loss, dx
  455. def shape(x):
  456. """
  457. Gets the array's dimensions.
  458. Args:
  459. x (numpy.ndarray): Input array.
  460. Returns:
  461. tuple, the shape/dimensions of the input array.
  462. """
  463. return np.array(np.shape(x))
  464. def expand_dims(x, axis):
  465. """
  466. Expands the shape of an array.
  467. Args:
  468. x (numpy.ndarray): Input array.
  469. axis (int): Position in the expanded axes where the new axis is placed.
  470. Returns:
  471. numpy.ndarray, view of input array with the number of dimensions increased by one.
  472. """
  473. return np.expand_dims(x, axis)
  474. def squeeze(x, axis):
  475. """
  476. Removes single-dimensional entries from the shape of an array.
  477. Args:
  478. x (numpy.ndarray): Input array.
  479. axis (Union[int, tuple[int]]): Selected subset of the single-dimensional entries in the shape.
  480. Returns:
  481. numpy.ndarray, the input numpy.ndarray, but with all or a subset of the dimensions of length
  482. 1 removed.
  483. """
  484. return np.squeeze(x, tuple(axis))
  485. def reshape(x, shp):
  486. """
  487. Applies a new shape to an array without changing its data.
  488. Args:
  489. x (numpy.ndarray): Input array.
  490. shp (tuple[int]): New shape to apply to x.
  491. Returns:
  492. numpy.ndarray, a new view object or a copy of input array.
  493. """
  494. return np.reshape(x, tuple(shp))
  495. def rank(x):
  496. """
  497. Gets number of array dimensions.
  498. Args:
  499. x (numpy.ndarray): Input array.
  500. Returns:
  501. int, number of input array dimensions.
  502. """
  503. return np.array(np.ndim(x))
  504. def logsoftmax(x):
  505. """
  506. Log softmax function.
  507. Args:
  508. x (numpy.ndarray): Input array.
  509. Returns:
  510. numpy.ndarray, the result of applying log softmax on the input array.
  511. """
  512. return np.array(np.log(softmax(x)))
  513. def transpose(x, axes=None):
  514. """
  515. Transposes an input array according to axes.
  516. Args:
  517. x (numpy.ndarray): Input array.
  518. axes (list): The axes to be transposed. Default: None.
  519. Returns:
  520. numpy.ndarray, transposed array.
  521. """
  522. return np.transpose(x, axes)
  523. def invert_permutation(x):
  524. """
  525. Gets the inverse permutation of an array.
  526. Args:
  527. x (numpy.ndarray): Input array.
  528. Returns:
  529. tuple, the inverse permutation of the input array.
  530. """
  531. x = np.array(x)
  532. y = np.argsort(x)
  533. return tuple(y)
  534. def select(cond, x, y):
  535. """
  536. Gets elements from x or y depending on cond.
  537. Args:
  538. cond (bool): Where True, yield x, otherwise yield y.
  539. x (numpy.ndarray): Values from which to choose.
  540. y (numpy.ndarray): Values from which to choose.
  541. Returns:
  542. numpy.ndarray, elements from x where condition is True, and elements from y elsewhere.
  543. """
  544. return np.where(cond, x, y)
  545. def sum_by_axis(x, axis):
  546. """
  547. Sum of array elements over a given axis.
  548. Args:
  549. x (numpy.ndarray): Input array.
  550. axis (Union[int, tuple[int]]): Axis or axes along which a sum is performed.
  551. Returns:
  552. numpy.ndarray, has the same shape as input array with the specified axis removed.
  553. """
  554. return np.sum(x, axis)
  555. def equal(x, y):
  556. """
  557. Gets (x == y) element-wise.
  558. Args:
  559. x (numpy.ndarray): Input array.
  560. y (numpy.ndarray): Input array.
  561. Returns:
  562. numpy.ndarray, element-wise comparison of x and y.
  563. """
  564. return np.equal(x, y)
  565. def not_equal(x, y):
  566. """
  567. Gets (x != y) element-wise.
  568. Args:
  569. x (numpy.ndarray): Input array.
  570. y (numpy.ndarray): Input array.
  571. Returns:
  572. numpy.ndarray, element-wise comparison of x and y.
  573. """
  574. return np.not_equal(x, y)
  575. def greater(x, y):
  576. """
  577. Get the truth value of (x > y) element-wise.
  578. Args:
  579. x (numpy.ndarray): Input array.
  580. y (numpy.ndarray): Input array.
  581. Returns:
  582. numpy.ndarray, element-wise comparison of x and y.
  583. """
  584. return np.greater(x, y)
  585. def less(x, y):
  586. """
  587. Get the truth value of (x < y) element-wise.
  588. Args:
  589. x (numpy.ndarray): Input array.
  590. y (numpy.ndarray): Input array.
  591. Returns:
  592. Array, element-wise comparison of x and y.
  593. """
  594. return np.less(x, y)
  595. def logical_not(x):
  596. """
  597. Gets the truth value of NOT x element-wise.
  598. Args:
  599. x (numpy.ndarray): Input array.
  600. Returns:
  601. bool, have the same shape as x of the NOT operation on elements of x.
  602. """
  603. return np.logical_not(x)
  604. def sqrt(x):
  605. """
  606. Gets the non-negative square-root of an numpy.ndarray, element-wise.
  607. Args:
  608. x (numpy.ndarray): Input array.
  609. Returns:
  610. numpy.ndarray, has the same shape as x, containing the positive square-root of each
  611. element in x.
  612. """
  613. return np.sqrt(x)
  614. def power(x, y):
  615. """
  616. First array elements raised to powers from second numpy.ndarray, element-wise.
  617. Args:
  618. x (numpy.ndarray): The bases array.
  619. y (numpy.ndarray): The exponents array.
  620. Returns:
  621. numpy.ndarray, the bases in x raised to the exponents in y.
  622. """
  623. return np.power(x, y)
  624. def exp(x):
  625. """
  626. Gets the exponential of all elements in the input array.
  627. Args:
  628. x (numpy.ndarray): Input array.
  629. Returns:
  630. numpy.ndarray, element-wise exponential of x.
  631. """
  632. return np.exp(x)
  633. def maximum(x, y):
  634. """
  635. Gets the max of x and y element-wise.
  636. If x > y, return x. Otherwise, return y.
  637. Args:
  638. x (numpy.ndarray): First input array.
  639. y (numpy.ndarray): Second input array ave the same type as x.
  640. Returns:
  641. numpy.ndarray, has the same type as x.
  642. """
  643. return np.maximum(x, y)
  644. def minimum(x, y):
  645. """
  646. Gets the min of x and y element-wise.
  647. If x < y, return x. Otherwise, return y.
  648. Args:
  649. x (numpy.ndarray): First input array.
  650. y (numpy.ndarray): Second input array have the same type as x.
  651. Returns:
  652. numpy.ndarray, has the same type as x.
  653. """
  654. return np.minimum(x, y)
  655. def all_(x, axis=(), keep_dims=False):
  656. """
  657. Check all array elements along a given axis evaluate to True.
  658. Args:
  659. x (numpy.ndarray): An array to be reduced.
  660. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  661. keep_dims (bool): Whether to keep the reduced dimensions.
  662. Returns:
  663. numpy.ndarray, has the same type as x.
  664. """
  665. axis = None if axis == () else axis
  666. return np.all(x, axis, keepdims=keep_dims)
  667. def any_(x, axis=(), keep_dims=False):
  668. """
  669. Check any array element along a given axis evaluate to True.
  670. Args:
  671. x (numpy.ndarray): An array to be reduced.
  672. axis (Union[None, int, tuple(int)): Dimensions of reduction.
  673. keep_dims (bool): Whether to keep the reduced dimensions.
  674. Returns:
  675. numpy.ndarray, has the same type as x.
  676. """
  677. axis = None if axis == () else axis
  678. return np.any(x, axis, keepdims=keep_dims)