You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

vm_me.py 24 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """VM implementations based on numpy."""
  16. import numpy as np
  17. from mindspore._checkparam import Rel
  18. from mindspore._checkparam import Validator as validator
  def avg_pooling(x, pool_h, pool_w, stride):
      """
      Applies average pooling over an input array.

      Args:
          x (numpy.ndarray): Input unpacked as (num, channel, height, width).
          pool_h (int): Height of the pooling window.
          pool_w (int): Width of the pooling window.
          stride (int): Step of the sliding window; must be greater than 0.

      Returns:
          numpy.ndarray, shaped (num, channel, out_h, out_w) with
          out_h = (height - pool_h) // stride + 1 and likewise for out_w.
      """
      validator.check_integer("stride", stride, 0, Rel.GT, None)
      num, channel, height, width = x.shape
      rows_h = (height - pool_h) // stride + 1
      rows_w = (width - pool_w) // stride + 1
      # Each row of `windows` holds the pool_h * pool_w values of one window
      # for one channel, in (num, out_h, out_w, channel) order.
      windows = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
      averaged = np.mean(windows, axis=1)
      grid = averaged.reshape((num, rows_h, rows_w, channel))
      return grid.transpose(0, 3, 1, 2)
  def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride):
      """
      Gets grad of average pooling.

      Every pooled output is the mean of one pool_h x pool_w window. Its
      incoming gradient is therefore spread evenly, divided by
      pool_h * pool_w, over the input positions of that window. Windows
      start at multiples of `stride`.

      Args:
          dout (numpy.ndarray): The grad of pre-layer, unpacked as
              (num, channel, out_h, out_w).
          origin_shape (tuple[int]): Shape of the array that was pooled.
          pool_h (int): Height of the pooling window.
          pool_w (int): Width of the pooling window.
          stride (Union[int, tuple[int]]): The stride of the sliding window. A
              tuple/list uses its last two entries as (stride_h, stride_w).

      Returns:
          numpy.ndarray, grad of average pooling with shape `origin_shape`.
      """
      if isinstance(stride, (tuple, list)):
          stride_h, stride_w = stride[-2], stride[-1]
      else:
          stride_h = stride_w = stride
      _, _, out_h, out_w = dout.shape
      dx = np.zeros(origin_shape)
      share = 1.0 / (pool_h * pool_w)
      for i in range(out_h):
          top = i * stride_h
          for j in range(out_w):
              left = j * stride_w
              # Keep the window axes (size 1) so dout broadcasts over the window.
              dx[:, :, top:top + pool_h, left:left + pool_w] += dout[:, :, i:i + 1, j:j + 1] * share
      return dx
  58. def _batch_norm(x, scale, shift, running_mean=None, running_var=None,
  59. eps=1e-05, momentum=0.1, is_training=True):
  60. """Batch normalization over an array."""
  61. _, c_h_w = x.shape
  62. # Handle running_mean and running_var are not None
  63. # if running_mean is None:
  64. # running_mean = np.zeros(c_h_w)
  65. # running_var = np.zeros(c_h_w)
  66. running_mean = np.zeros(c_h_w)
  67. running_var = np.zeros(c_h_w)
  68. if np.ndim(scale) > 0:
  69. scale = scale.mean()
  70. if np.ndim(shift) > 0:
  71. shift = shift.mean()
  72. if is_training:
  73. x_mean = np.mean(x, axis=0)
  74. x_var = np.var(x, axis=0)
  75. # Normalization followed by Affine transformation
  76. x_norm = (x - x_mean)/np.sqrt(x_var + eps)
  77. # Estimate running average of mean and variance to use at test time
  78. running_mean = momentum * running_mean + (1 - momentum) * x_mean
  79. running_var = momentum * running_var + (1 - momentum) * x_var
  80. else:
  81. # normalize using running average
  82. x_norm = (x - running_mean)/np.sqrt(running_var + eps)
  83. x_mean = running_mean
  84. x_var = running_var
  85. out = scale * x_norm + shift
  86. return out, x_mean, x_var, running_mean, running_var
  def batch_norm(x, scale=1, shift=0, mean=None, variance=None,
                 eps=1e-05, momentum=0.1, is_training=True):
      """
      Batch normalization over an array of any rank >= 1.

      Inputs that are not 2-D are viewed as (x.shape[0], -1) before being
      normalized by `_batch_norm`; the result is reshaped back to x.shape.

      Returns:
          tuple, (out, np.array(scale), np.array(shift), running_mean, running_var).
      """
      original_shape = x.shape
      flat = x if x.ndim == 2 else x.reshape(original_shape[0], -1)
      normed, _, _, new_mean, new_var = _batch_norm(
          flat, scale, shift, mean, variance, eps, momentum, is_training)
      return (normed.reshape(*original_shape), np.array(scale), np.array(shift),
              new_mean, new_var)
  def _batch_norm_grad(dout, x, scale, save_mean, save_inv_variance,
                       eps=1e-05, momentum=0.1, is_training=True):
      """
      Gradient of batch normalization over a 2-D (batch, features) array.

      Args:
          dout (numpy.ndarray): Upstream gradient, expected in the same 2-D
              layout as the flattened `x`.
          x (numpy.ndarray): Forward input; viewed as (x.shape[0], -1) if not 2-D.
          scale (Union[float, numpy.ndarray]): Scale used in the forward pass;
              reduced to its mean if an array.
          save_mean: Forwarded to `_batch_norm` as `running_mean`.
          save_inv_variance: Forwarded to `_batch_norm` as `running_var`.
          eps (float): Added to the variance before taking the square root.
          momentum (float): Forwarded to `_batch_norm`.
          is_training (bool): Forwarded to `_batch_norm`.

      Returns:
          tuple, (dx, dgamma, dbeta); dx is in the flattened 2-D layout.
      """
      if x.ndim != 2:
          x = x.reshape(x.shape[0], -1)
      if np.ndim(scale) > 0:
          scale = scale.mean()
      # Run the forward pass with unit scale and zero shift so that the first
      # return value is the normalized input x_hat itself. dgamma needs x_hat,
      # not scale * x_hat.
      x_norm, x_mean, x_var, _, _ = _batch_norm(x, 1, shift=0, running_mean=save_mean,
                                                running_var=save_inv_variance,
                                                eps=eps, momentum=momentum,
                                                is_training=is_training)
      batch_size = x.shape[0]
      centered = x - x_mean
      inv_std = 1.0 / np.sqrt(x_var + eps)
      dx_norm = scale * dout
      dvar = np.sum(dx_norm * centered * ((x_var + eps) ** (-3.0 / 2)) * (-1.0 / 2), axis=0)
      dmean = (np.sum(dx_norm * (-inv_std), axis=0)
               + dvar * (np.sum(-2 * centered, axis=0) * (1.0 / batch_size)))
      dx = dx_norm * inv_std + dvar * (2.0 * centered / batch_size) + dmean * (1.0 / batch_size)
      dgamma = np.sum(dout * x_norm, axis=0)
      dbeta = np.sum(dout, axis=0)
      return dx, dgamma, dbeta
  def batch_norm_grad(dy, x, scale, save_mean, save_inv_variance):
      """
      Gradient of batch normalization for inputs of any rank >= 1.

      A `dy` that is not 2-D is viewed as (dy.shape[0], -1). The gradient with
      respect to the input is reshaped back to x.shape.

      Returns:
          tuple, (dx, dgamma, dbeta).
      """
      grad_2d = dy if dy.ndim == 2 else dy.reshape(dy.shape[0], -1)
      dx, dgamma, dbeta = _batch_norm_grad(grad_2d, x, scale, save_mean, save_inv_variance)
      return dx.reshape(*x.shape), dgamma, dbeta
  def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
      """
      Rearranges row vectors back into an image (the reverse layout of im2col).

      Contributions from overlapping windows are summed.

      Args:
          col (numpy.ndarray): Rows reshapeable to
              (batch_num, out_h, out_w, channel, filter_h, filter_w).
          input_shape (tuple[int]): (batch_num, channel, height, width) of the result.
          filter_h (int): Window height.
          filter_w (int): Window width.
          stride (Union[int, tuple[int]]): An int, (stride_h, stride_w), or a
              4-tuple whose last two entries are used.
          pad (int): Zero padding that was applied on each spatial border.

      Returns:
          numpy.ndarray, shaped `input_shape`, with col's dtype.

      Raises:
          ValueError: If `stride` is not an int or a tuple of two or four ints.
      """
      if isinstance(stride, int):
          stride_h = stride
          stride_w = stride
      elif isinstance(stride, tuple) and len(stride) == 2:
          stride_h = stride[0]
          stride_w = stride[1]
      elif isinstance(stride, tuple) and len(stride) == 4:
          stride_h = stride[2]
          stride_w = stride[3]
      else:
          raise ValueError(f"The \'stride\' should be an int number or "
                           f"a tuple of two or four int numbers, but got {stride}")
      batch_num, channel, height, width = input_shape
      out_h = (height + 2*pad - filter_h)//stride_h + 1
      out_w = (width + 2*pad - filter_w)//stride_w + 1
      col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
          .transpose(0, 3, 4, 5, 1, 2)
      # Buffer covers the padded image plus stride - 1 extra rows/cols; only the
      # un-padded [pad:pad+height, pad:pad+width] region is returned.
      img = np.zeros((batch_num,
                      channel,
                      height + 2*pad + stride_h - 1,
                      width + 2*pad + stride_w - 1)).astype(col.dtype)
      for y in range(filter_h):
          y_max = y + stride_h*out_h
          for x in range(filter_w):
              # Horizontal extent and step use the width stride.
              x_max = x + stride_w*out_w
              img[:, :, y:y_max:stride_h, x:x_max:stride_w] += col[:, :, y, x, :, :]
      return img[:, :, pad:height + pad, pad:width + pad]
  def convolve(x, w, b=None, pad_mode="valid"):
      """
      Gets the discrete, linear convolution of two one-dimensional sequences.

      Args:
          x (numpy.ndarray): One-dimensional input array.
          w (numpy.ndarray): One-dimensional input array.
          b (Union[numpy.ndarray, float]): Bias added to the convolution result
              when it is not None; broadcast against the result. Default: None.
          pad_mode (str): "full" returns the convolution at each point of
              overlap, with an output shape of (N+M-1,); "same" returns output
              of length max(M, N); "valid" returns output of length
              max(M, N) - min(M, N) + 1. Any other value falls back to "full".
              Default: "valid".

      Returns:
          numpy.ndarray, discrete, linear convolution of x and w, plus b.
      """
      if pad_mode not in {"same", "valid"}:
          pad_mode = "full"
      y = np.convolve(x, w, pad_mode)
      # `if b:` raises for multi-element arrays and silently skips a zero
      # bias, so test for None explicitly.
      if b is not None:
          # Out-of-place add: an integer result can still take a float bias.
          y = y + b
      return y
  def conv2d(x, weight, bias=None, stride=1, pad=0,
             dilation=1, groups=1, padding_mode='zeros'):
      """
      Convolution 2D.

      Args:
          x (numpy.ndarray): Input unpacked as (batch_num, channel, x_h, x_w).
          weight (numpy.ndarray): Filters unpacked as (filter_num, channel, filter_h, filter_w).
          bias (Union[numpy.ndarray, float]): Optional bias. A 1-D bias holds one
              value per output channel; any other shape is added with plain
              broadcasting. Default: None.
          stride (Union[int, tuple[int]]): An int, a 2-tuple, or a 4-tuple whose
              last two entries are used; entries must be positive ints.
          pad (int): Zero padding on each spatial border.
          dilation (Union[int, tuple[int]]): Same accepted forms as `stride`.
          groups (int): Unused.
          padding_mode (str): Unused.

      Returns:
          numpy.ndarray, shaped (batch_num, filter_num, out_h, out_w).

      Raises:
          ValueError: If `stride` or `dilation` is not a positive int or a
              tuple of two positive ints (after 4-tuple reduction).
      """
      # pylint: disable=unused-argument
      validator.check_value_type('stride', stride, (int, tuple), None)
      if isinstance(stride, int):
          stride = (stride, stride)
      elif len(stride) == 4:
          stride = (stride[2], stride[3])
      if (len(stride) != 2 or (not isinstance(stride[0], int))
              or (not isinstance(stride[1], int))
              or stride[0] < 1 or stride[1] < 1):
          raise ValueError(f"The \'stride\' of \'conv2d\' should be an positive int number or "
                           f"a tuple of two positive int numbers, but got {stride}")
      stride_h = stride[0]
      stride_w = stride[1]
      validator.check_value_type('dilation', dilation, (int, tuple), None)
      if isinstance(dilation, int):
          dilation = (dilation, dilation)
      elif len(dilation) == 4:
          dilation = (dilation[2], dilation[3])
      if (len(dilation) != 2 or (not isinstance(dilation[0], int))
              or (not isinstance(dilation[1], int))
              or dilation[0] < 1 or dilation[1] < 1):
          raise ValueError(f"The \'dilation\' of \'conv2d\' should be an positive int number or "
                           f"a tuple of two positive int numbers, but got {dilation}")
      dilation_h = dilation[0]
      dilation_w = dilation[1]
      batch_num, _, x_h, x_w = x.shape
      filter_num, _, filter_h, filter_w = weight.shape
      out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
      out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
      col = im2col(x, filter_h, filter_w, stride, pad, dilation)
      col_w = np.reshape(weight, (filter_num, -1)).T
      out = np.dot(col, col_w)
      out = out.reshape(batch_num, out_h, out_w, -1).transpose(0, 3, 1, 2)
      if bias is not None:
          bias = np.asarray(bias)
          if bias.ndim == 1:
              # One value per output channel: align it with axis 1 of
              # (batch, channel, H, W) instead of the trailing width axis.
              bias = bias.reshape(1, -1, 1, 1)
          out = out + bias
      return out
  def conv2d_backprop_filter(dout, x, w_size, stride=1, pad=0):
      """
      Backpropagation filter for conv2d: gradient of the loss w.r.t. the weights.

      Args:
          dout (numpy.ndarray): Output gradient, transposed (0, 2, 3, 1) and
              flattened to one row per output position.
          x (numpy.ndarray): Forward input of the convolution.
          w_size (tuple[int]): (filter_num, channel, filter_height, filter_width).
          stride (Union[int, tuple[int]]): Forwarded to im2col.
          pad (int): Forwarded to im2col.

      Returns:
          numpy.ndarray, shaped `w_size`.
      """
      filter_num, channel, filter_height, filter_width = w_size
      grad_rows = np.transpose(dout, (0, 2, 3, 1)).reshape(-1, filter_num)
      patches = im2col(x, filter_height, filter_width, stride, pad)
      dw_cols = np.dot(patches.T, grad_rows)
      return dw_cols.T.reshape(filter_num, channel, filter_height, filter_width)
  def conv2d_backprop_input(dout, x_size, weight, stride=1, pad=0):
      """
      Backpropagation input for conv2d: gradient of the loss w.r.t. the input.

      Args:
          dout (numpy.ndarray): Output gradient, transposed (0, 2, 3, 1) and
              flattened to one row per output position.
          x_size (tuple[int]): Shape of the forward input, passed to col2im.
          weight (numpy.ndarray): Filters unpacked as (filter_num, _, filter_h, filter_w).
          stride (Union[int, tuple[int]]): Forwarded to col2im.
          pad (int): Forwarded to col2im.

      Returns:
          numpy.ndarray, shaped `x_size`.
      """
      filter_num, _, filter_h, filter_w = weight.shape
      grad_rows = np.transpose(dout, (0, 2, 3, 1)).reshape(-1, filter_num)
      kernel_cols = weight.reshape(filter_num, -1).T
      patch_grads = np.dot(grad_rows, kernel_cols.T)
      return col2im(patch_grads, x_size, filter_h, filter_w, stride, pad)
  def flatten(x):
      """
      Collapses an array into one dimension.

      Args:
          x (numpy.ndarray): Array to collapse.

      Returns:
          numpy.ndarray, a one-dimensional copy produced by `ndarray.flatten`.
      """
      flat = x.flatten()
      return flat
  def flatten2(x):
      """
      Reshapes an array into a single row.

      Args:
          x (numpy.ndarray): Array to reshape.

      Returns:
          numpy.ndarray, a 2-D array of shape (1, x.size) from `x.reshape(1, -1)`.
      """
      row_shape = (1, -1)
      return x.reshape(row_shape)
  def flatten_batch(x):
      """
      Flattens every sample of a batch while keeping the leading batch axis.

      Args:
          x (numpy.ndarray): Batched array; axis 0 is kept.

      Returns:
          numpy.ndarray, shaped (x.shape[0], -1).
      """
      batch = x.shape[0]
      return x.reshape((batch, -1))
  def flatten_grad(dout, x):
      """
      Grad of flatten: reshapes the flat gradient back to its original shape.

      Args:
          dout (numpy.ndarray): Gradient of the flattened output.
          x (tuple[int]): Target shape, passed as the new shape to np.reshape.

      Returns:
          numpy.ndarray, `dout` reshaped to `x`.
      """
      restored = np.reshape(dout, x)
      return restored
  def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
      """
      Rearranges image windows into row vectors.

      Args:
          img (numpy.ndarray): Input unpacked as (batch_num, channel, height, width).
          filter_h (int): Number of taps along the height.
          filter_w (int): Number of taps along the width.
          stride (Union[int, tuple[int]]): An int, (stride_h, stride_w), or a
              4-tuple whose last two entries are used.
          pad (int): Zero padding added on each spatial border.
          dilation (Union[int, tuple[int]]): Spacing between taps; same
              accepted forms as `stride`.

      Returns:
          numpy.ndarray, shaped (batch_num * out_h * out_w, channel * filter_h * filter_w),
          one row per output position.

      Raises:
          ValueError: If `stride` or `dilation` is not an int or a tuple of two or four ints.
      """
      if isinstance(stride, int):
          stride_h = stride
          stride_w = stride
      elif isinstance(stride, tuple) and len(stride) == 2:
          stride_h = stride[0]
          stride_w = stride[1]
      elif isinstance(stride, tuple) and len(stride) == 4:
          stride_h = stride[2]
          stride_w = stride[3]
      else:
          raise ValueError(f"The \'stride\' should be an int number or "
                           f"a tuple of two or four int numbers, but got {stride}")
      if isinstance(dilation, int):
          dilation_h = dilation
          dilation_w = dilation
      elif isinstance(dilation, tuple) and len(dilation) == 2:
          dilation_h = dilation[0]
          dilation_w = dilation[1]
      elif isinstance(dilation, tuple) and len(dilation) == 4:
          dilation_h = dilation[2]
          dilation_w = dilation[3]
      else:
          raise ValueError(f"The \'dilation\' should be an int number or "
                           f"a tuple of two or four int numbers, but got {dilation}")
      batch_num, channel, height, width = img.shape
      out_h = (height + 2*pad - filter_h - (filter_h - 1) * (dilation_h - 1))//stride_h + 1
      out_w = (width + 2*pad - filter_w - (filter_w - 1) * (dilation_w - 1))//stride_w + 1
      img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
      col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
      for ky in range(filter_h):
          # Tap ky sits dilation_h rows after tap ky - 1.
          y_start = ky * dilation_h
          y_stop = y_start + stride_h * out_h
          for kx in range(filter_w):
              x_start = kx * dilation_w
              # Horizontal extent and step use the width stride.
              x_stop = x_start + stride_w * out_w
              col[:, :, ky, kx, :, :] = img[:, :, y_start:y_stop:stride_h, x_start:x_stop:stride_w]
      col = col.transpose(0, 4, 5, 1, 2, 3).reshape(batch_num*out_h*out_w, -1)
      return col
  def matmul(x, w, b=None):
      """
      Dot product of array x and w, then plus array b if b is not None.

      Args:
          x (numpy.ndarray): Represents the input array.
          w (numpy.ndarray): Represents weights array.
          b (Union[numpy.ndarray, float]): Bias, broadcast against np.dot(x, w).
              Default: None.

      Returns:
          numpy.ndarray, the result of (x*w + b).
      """
      y = np.dot(x, w)
      # `if b:` raises for multi-element arrays and skips a zero bias.
      if b is not None:
          # Out-of-place add so an integer product can take a float bias.
          y = y + b
      return y
  def max_pooling(x, pool_h, pool_w, stride):
      """
      Max pooling over an array unpacked as (num, channel, height, width).

      Args:
          x (numpy.ndarray): Input array.
          pool_h (int): Window height.
          pool_w (int): Window width.
          stride (int): Window step; must be greater than 0.

      Returns:
          numpy.ndarray, shaped (num, channel, out_h, out_w).
      """
      validator.check_integer("stride", stride, 0, Rel.GT, None)
      num, channel, height, width = x.shape
      rows_h = (height - pool_h) // stride + 1
      rows_w = (width - pool_w) // stride + 1
      # One row per (num, out_h, out_w, channel) window.
      windows = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
      peaks = np.max(windows, axis=1)
      return peaks.reshape((num, rows_h, rows_w, channel)).transpose(0, 3, 1, 2)
  def max_pool_grad(x, dout, pool_h, pool_w, stride):
      """
      Grad of max pooling: routes each output gradient to its window's maximum.

      Args:
          x (numpy.ndarray): Forward input.
          dout (numpy.ndarray): Gradient of the pooled output (num, channel, out_h, out_w).
          pool_h (int): Window height.
          pool_w (int): Window width.
          stride (int): Window step.

      Returns:
          numpy.ndarray, shaped x.shape.
      """
      dout_nhwc = dout.transpose(0, 2, 3, 1)
      window = pool_h * pool_w
      patches = im2col(x, pool_h, pool_w, stride).reshape(-1, window)
      winners = np.argmax(patches, axis=1)
      scatter = np.zeros((dout_nhwc.size, window))
      scatter[np.arange(winners.size), winners.flatten()] = dout_nhwc.flatten()
      scatter = scatter.reshape(dout_nhwc.shape + (window,))
      lead = scatter.shape[0] * scatter.shape[1] * scatter.shape[2]
      return col2im(scatter.reshape(lead, -1), x.shape, pool_h, pool_w, stride)
  def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
      """
      Grad of max pooling with a precomputed argmax.

      Args:
          x (numpy.ndarray): Forward input; only its shape is used.
          dout (numpy.ndarray): Gradient of the pooled output (num, channel, out_h, out_w).
          arg_max (numpy.ndarray): Index of the maximum inside each window. A
              4-D array is taken to be in (num, channel, out_h, out_w) layout,
              like the argmax returned by `max_pool_with_argmax`. Any other
              shape is used in flattened (num, out_h, out_w, channel) order.
          pool_h (int): Window height.
          pool_w (int): Window width.
          stride (int): Window step.

      Returns:
          numpy.ndarray, shaped x.shape.
      """
      dout = dout.transpose(0, 2, 3, 1)
      arg_max = np.asarray(arg_max)
      if arg_max.ndim == 4:
          # Match the (num, out_h, out_w, channel) order of the transposed dout
          # so that flattened positions line up element by element.
          arg_max = arg_max.transpose(0, 2, 3, 1)
      pool_size = pool_h * pool_w
      dmax = np.zeros((dout.size, pool_size))
      dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
      dmax = dmax.reshape(dout.shape + (pool_size,))
      dcol = dmax.reshape(dmax.shape[0]*dmax.shape[1]*dmax.shape[2], -1)
      dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
      return dx
  def max_pool_with_argmax(x, pool_h, pool_w, stride):
      """
      Max pooling that also reports where each maximum was found.

      Args:
          x (numpy.ndarray): Input unpacked as (num, channel, height, width).
          pool_h (int): Window height.
          pool_w (int): Window width.
          stride (int): Window step; must be greater than 0.

      Returns:
          tuple, (max values, argmax indices within each window), both shaped
          (num, channel, out_h, out_w).
      """
      validator.check_integer("stride", stride, 0, Rel.GT, None)
      num, channel, height, width = x.shape
      rows_h = (height - pool_h) // stride + 1
      rows_w = (width - pool_w) // stride + 1
      windows = im2col(x, pool_h, pool_w, stride).reshape(-1, pool_h * pool_w)
      grid_shape = (num, rows_h, rows_w, channel)
      to_nchw = (0, 3, 1, 2)
      maxima = np.max(windows, axis=1).reshape(grid_shape).transpose(to_nchw)
      positions = np.argmax(windows, axis=1).reshape(grid_shape).transpose(to_nchw)
      return maxima, positions
  def relu(x):
      """
      Rectified linear unit, computed as x multiplied by the mask (x > 0).

      Args:
          x (numpy.ndarray): The input array.

      Returns:
          numpy.ndarray, x with non-positive entries zeroed.
      """
      positive = x > 0
      return x * positive
  def relu_grad(y):
      """
      Grad of relu.

      Args:
          y (numpy.ndarray): The input array; it is not modified.

      Returns:
          numpy.ndarray, a copy of y with entries <= 0 set to 0 and entries > 0
          set to 1 (same dtype as y).
      """
      # Work on a copy so the caller's array is not overwritten in place.
      grad = y.copy()
      grad[grad <= 0] = 0
      grad[grad > 0] = 1
      return grad
  def sigmoid(x):
      """
      Logistic sigmoid, 1 / (1 + exp(-x)), applied element-wise.

      Args:
          x (numpy.ndarray): The input array.

      Returns:
          numpy.ndarray, the sigmoid of each element.
      """
      denominator = 1 + np.exp(x * -1)
      return 1 / denominator
  def tanh(x):
      """
      Computes hyperbolic tangent element-wise.

      Args:
          x (numpy.ndarray): The input array.

      Returns:
          numpy.ndarray, the array applied tanh.
      """
      # (e^x - e^-x) / (e^x + e^-x) overflows to inf/inf = nan for large |x|;
      # np.tanh saturates correctly to +/-1.
      return np.tanh(x)
  def softmax(x, axis=None):
      """
      Softmax, exp(x) / sum(exp(x)), delegated to scipy.special.softmax.

      Args:
          x (numpy.ndarray): Input array.
          axis (Union[int, tuple[int]]): Axis or axes to normalize over; None
              normalizes over the whole array. Default: None.

      Returns:
          numpy.ndarray, same shape as x.
      """
      from scipy.special import softmax as _scipy_softmax
      probabilities = _scipy_softmax(x, axis)
      return probabilities
  def softmax_cross_entropy_with_logits(logits, labels):
      """
      Softmax cross entropy between logits and target probabilities.

      Args:
          logits (numpy.ndarray): Unnormalized scores; softmax is taken over the
              last axis.
          labels (numpy.ndarray): Target probabilities, indexed by sample along
              axis 0. Presumably one-hot with the same shape as `logits`;
              verify against callers.

      Returns:
          tuple, (loss, dx). loss is the element-wise -log(prob) * labels, not
          reduced over classes or samples. dx is prob - labels.
      """
      sample_num = labels.shape[0]
      # Normalize each sample over its class (last) axis; axis=None would
      # normalize across the whole batch and couple unrelated samples.
      prob = softmax(logits, axis=-1)
      loss = -np.log(prob[range(sample_num)]) * labels
      dx = prob.copy()
      dx[range(sample_num)] -= labels
      return loss, dx
  def shape(x):
      """
      Gets the array's dimensions.

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, a 1-D array holding np.shape(x).
      """
      dims = np.shape(x)
      return np.array(dims)
  def expand_dims(x, axis):
      """
      Inserts a new length-1 axis into an array.

      Args:
          x (numpy.ndarray): Input array.
          axis (int): Position of the new axis in the result.

      Returns:
          numpy.ndarray, x with one extra dimension (from np.expand_dims).
      """
      expanded = np.expand_dims(x, axis)
      return expanded
  def squeeze(x, axis):
      """
      Removes single-dimensional entries from the shape of an array.

      Args:
          x (numpy.ndarray): Input array.
          axis (Union[None, int, tuple[int], list[int]]): Length-1 axes to
              remove; None removes all of them.

      Returns:
          numpy.ndarray, the input numpy.ndarray, but with all or a subset of
          the dimensions of length 1 removed.
      """
      # tuple(axis) raises TypeError for an int or None, so pass those through
      # unchanged; np.squeeze accepts both.
      if axis is None or isinstance(axis, (int, np.integer)):
          return np.squeeze(x, axis)
      return np.squeeze(x, tuple(axis))
  def reshape(x, shp):
      """
      Gives an array a new shape without changing its data.

      Args:
          x (numpy.ndarray): Input array.
          shp (Union[tuple[int], list[int]]): Target shape; converted to a tuple.

      Returns:
          numpy.ndarray, the result of np.reshape (a view or a copy).
      """
      target = tuple(shp)
      return np.reshape(x, target)
  def rank(x):
      """
      Gets number of array dimensions.

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, a 0-d array holding np.ndim(x).
      """
      ndim = np.ndim(x)
      return np.array(ndim)
  def logsoftmax(x):
      """
      Log softmax function, normalized over the whole array (axis=None).

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, the result of applying log softmax on the input array.
      """
      # log(softmax(x)) underflows to -inf once a probability falls below the
      # float range; log_softmax subtracts the log-sum-exp instead and stays finite.
      from scipy.special import log_softmax
      return np.array(log_softmax(x))
  def transpose(x, axes=None):
      """
      Permutes the axes of an array.

      Args:
          x (numpy.ndarray): Input array.
          axes (list): New axis order; None reverses the axes. Default: None.

      Returns:
          numpy.ndarray, the result of np.transpose.
      """
      permuted = np.transpose(x, axes)
      return permuted
  def invert_permutation(x):
      """
      Gets the inverse permutation of an array.

      Args:
          x (Union[numpy.ndarray, list, tuple]): A permutation of 0..n-1.

      Returns:
          tuple, positions that sort x, i.e. the inverse permutation.
      """
      order = np.argsort(np.array(x))
      return tuple(order)
  def select(cond, x, y):
      """
      Picks elements from x where cond is true and from y elsewhere.

      Args:
          cond (Union[bool, numpy.ndarray]): Selection mask.
          x (numpy.ndarray): Values used where cond is true.
          y (numpy.ndarray): Values used where cond is false.

      Returns:
          numpy.ndarray, the result of np.where(cond, x, y).
      """
      chosen = np.where(cond, x, y)
      return chosen
  def sum_by_axis(x, axis):
      """
      Sums array elements along the given axis or axes.

      Args:
          x (numpy.ndarray): Input array.
          axis (Union[int, tuple[int]]): Axis or axes to reduce.

      Returns:
          numpy.ndarray, x reduced by np.sum over `axis`.
      """
      total = np.sum(x, axis)
      return total
  def equal(x, y):
      """
      Element-wise equality test (x == y).

      Args:
          x (numpy.ndarray): First operand.
          y (numpy.ndarray): Second operand.

      Returns:
          numpy.ndarray, boolean result of np.equal(x, y).
      """
      mask = np.equal(x, y)
      return mask
  def not_equal(x, y):
      """
      Element-wise inequality test (x != y).

      Args:
          x (numpy.ndarray): First operand.
          y (numpy.ndarray): Second operand.

      Returns:
          numpy.ndarray, boolean result of np.not_equal(x, y).
      """
      mask = np.not_equal(x, y)
      return mask
  def greater(x, y):
      """
      Element-wise comparison x > y.

      Args:
          x (numpy.ndarray): First operand.
          y (numpy.ndarray): Second operand.

      Returns:
          numpy.ndarray, boolean result of np.greater(x, y).
      """
      mask = np.greater(x, y)
      return mask
  def less(x, y):
      """
      Element-wise comparison x < y.

      Args:
          x (numpy.ndarray): First operand.
          y (numpy.ndarray): Second operand.

      Returns:
          numpy.ndarray, boolean result of np.less(x, y).
      """
      mask = np.less(x, y)
      return mask
  def logical_not(x):
      """
      Element-wise logical NOT.

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, boolean result of np.logical_not(x), same shape as x.
      """
      negated = np.logical_not(x)
      return negated
  def sqrt(x):
      """
      Element-wise non-negative square root.

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, np.sqrt(x), same shape as x.
      """
      roots = np.sqrt(x)
      return roots
  def power(x, y):
      """
      Raises each element of x to the matching exponent in y.

      Args:
          x (numpy.ndarray): Bases.
          y (numpy.ndarray): Exponents.

      Returns:
          numpy.ndarray, np.power(x, y).
      """
      raised = np.power(x, y)
      return raised
  def exp(x):
      """
      Element-wise exponential.

      Args:
          x (numpy.ndarray): Input array.

      Returns:
          numpy.ndarray, np.exp(x).
      """
      exponentiated = np.exp(x)
      return exponentiated
  def maximum(x, y):
      """
      Element-wise maximum of two arrays.

      Returns x where x > y, otherwise y.

      Args:
          x (numpy.ndarray): First input array.
          y (numpy.ndarray): Second input array, expected to have the same type as x.

      Returns:
          numpy.ndarray, np.maximum(x, y).
      """
      larger = np.maximum(x, y)
      return larger
  def minimum(x, y):
      """
      Element-wise minimum of two arrays.

      Returns x where x < y, otherwise y.

      Args:
          x (numpy.ndarray): First input array.
          y (numpy.ndarray): Second input array, expected to have the same type as x.

      Returns:
          numpy.ndarray, np.minimum(x, y).
      """
      smaller = np.minimum(x, y)
      return smaller