You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_backend.py 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import absolute_import, division, print_function
  4. from .tensorflow_nn import nchw_to_nhwc, nhwc_to_nchw
  5. import tensorflow as tf
  6. _dtypeDict = {
  7. 'DType': tf.DType,
  8. 'float16': tf.float16,
  9. 'float32': tf.float32,
  10. 'float64': tf.float64,
  11. 'int8': tf.int8,
  12. 'int16': tf.int16,
  13. 'int32': tf.int32,
  14. 'int64': tf.int64,
  15. 'uint8': tf.uint8,
  16. 'uint16': tf.uint16,
  17. 'uint32': tf.uint32,
  18. 'uint64': tf.uint64
  19. }
  20. DType = tf.DType
  21. float16 = tf.float16
  22. float32 = tf.float32
  23. float64 = tf.float64
  24. int8 = tf.int8
  25. int16 = tf.int16
  26. int32 = tf.int32
  27. int64 = tf.int64
  28. uint8 = tf.uint8
  29. uint16 = tf.uint16
  30. uint32 = tf.uint32
  31. uint64 = tf.uint64
  32. # isinstance input output
  33. # TensorLike = tf_ops._TensorLike
  34. def set_context(**kwargs):
  35. raise Exception("Using TenosrFlow backend,You don't need to set context")
  36. def get_tensor_shape(x):
  37. return x.get_shape().as_list()
  38. # initializers
  39. def zeros(shape, dtype=tf.float32):
  40. """
  41. Creates a tensor with all elements set to zero.
  42. Parameters
  43. ----------
  44. shape : A list of integers
  45. a tuple of integers, or a 1-D Tensor of type int32.
  46. dtype : tensor
  47. The DType of an element in the resulting Tensor
  48. Returns
  49. -------
  50. A Tensor with all elements set to zero.
  51. """
  52. return tf.zeros(shape=shape, dtype=dtype)
  53. def ones(shape, dtype=tf.float32):
  54. """
  55. Creates a tensor with all elements set to ones.
  56. Parameters
  57. ----------
  58. shape : A list of integers
  59. a tuple of integers, or a 1-D Tensor of type int32.
  60. dtype : tensor
  61. The DType of an element in the resulting Tensor
  62. Returns
  63. -------
  64. A Tensor with all elements set to zero.
  65. """
  66. return tf.ones(shape=shape, dtype=dtype)
  67. def constant(value, dtype=tf.float32, shape=None):
  68. """
  69. Creates a constant tensor from a tensor-like object.
  70. Parameters
  71. ----------
  72. value : list
  73. A constant value (or list) of output type dtype.
  74. dtype : tensor
  75. The type of the elements of the resulting tensor.
  76. shape : tuple
  77. Optional dimensions of resulting tensor.
  78. Returns
  79. -------
  80. A Constant Tensor.
  81. """
  82. return tf.constant(value=value, dtype=dtype, shape=shape)
  83. def random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None):
  84. """
  85. Outputs random values from a uniform distribution.
  86. Parameters
  87. ----------
  88. shape : tuple
  89. A 1-D integer Tensor or Python array. The shape of the output tensor.
  90. minval : int
  91. The lower bound on the range of random values to generate (inclusive). Defaults to 0.
  92. maxval : int
  93. The upper bound on the range of random values to generate (exclusive). Defaults to 1 if dtype is floating point.
  94. dtype : tensor
  95. The type of the output: float16, float32, float64, int32, or int64.
  96. seed : int
  97. Used in combination with tf.random.set_seed to create a reproducible sequence of tensors across multiple calls.
  98. Returns
  99. -------
  100. A tensor of the specified shape filled with random uniform values.
  101. """
  102. outputs = tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
  103. return outputs
  104. def random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None):
  105. """
  106. Outputs random values from a normal distribution.
  107. Parameters
  108. ----------
  109. shape : tuple
  110. A 1-D integer Tensor or Python array. The shape of the output tensor.
  111. mean : float
  112. The mean of the normal distribution
  113. stddev : float
  114. The standard deviation of the normal distribution.
  115. dtype : tensor
  116. The type of the output.
  117. seed : A Python integer
  118. Used to create a random seed for the distribution
  119. Returns
  120. -------
  121. A tensor of the specified shape filled with random normal values.
  122. """
  123. outputs = tf.random.normal(shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
  124. return outputs
  125. def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None):
  126. """
  127. Outputs random values from a truncated normal distribution.
  128. Parameters
  129. ----------
  130. shape : tuple
  131. A 1-D integer Tensor or Python array. The shape of the output tensor.
  132. mean : float
  133. The mean of the normal distribution
  134. stddev : float
  135. The standard deviation of the normal distribution.
  136. dtype : tensor
  137. The type of the output.
  138. seed : A Python integer
  139. Used to create a random seed for the distribution
  140. Returns
  141. -------
  142. A tensor of the specified shape filled with random truncated normal values.
  143. """
  144. outputs = tf.random.truncated_normal(shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
  145. return outputs
  146. def he_normal(shape, dtype, seed=None):
  147. """
  148. He normal initializer.
  149. Parameters
  150. ----------
  151. seed : A Python integer.
  152. Used to seed the random generator.
  153. shape : tuple
  154. A 1-D integer Tensor or Python array. The shape of the output tensor.
  155. dtype : tensor
  156. The type of the output.
  157. Returns
  158. -------
  159. A tensor of the specified shape filled with he normal values.
  160. """
  161. return tf.initializers.he_normal(seed)(shape=shape, dtype=dtype)
  162. def Variable(initial_value, name, trainable=True):
  163. """
  164. Creates a new variable with value initial_value.
  165. Parameters
  166. ----------
  167. initial_value : tensor
  168. A Tensor, or Python object convertible to a Tensor
  169. name : str
  170. Optional name for the variable. Defaults to 'Variable' and gets uniquified automatically.
  171. Returns
  172. -------
  173. Variable
  174. """
  175. var = tf.Variable(initial_value=initial_value, name=name, trainable=trainable)
  176. return var
  177. class MatMul(object):
  178. def __init__(self):
  179. pass
  180. def __call__(self, a, b):
  181. return tf.matmul(a, b)
  182. def matmul(a, b):
  183. """
  184. Multiplies matrix a by matrix b, producing a * b.
  185. Parameters
  186. ----------
  187. a : tensor
  188. type float16, float32, float64, int32, complex64, complex128 and rank > 1.
  189. b : tensor
  190. with same type and rank as a.
  191. Returns
  192. -------
  193. A Tensor of the same type as a and b
  194. """
  195. outputs = tf.matmul(a, b)
  196. return outputs
  197. def add(value, bias):
  198. """
  199. Returns x + y element-wise.
  200. Parameters
  201. ----------
  202. value : tensor.
  203. Must be one of the following types: bfloat16, half, float32, float64,
  204. uint8, int8, int16, int32, int64, complex64, complex128, string.
  205. bias : tensor
  206. Must have the same type as a
  207. Returns
  208. -------
  209. A Tensor. Has the same type as a.
  210. """
  211. outputs = tf.add(value, bias)
  212. return outputs
  213. def dtypes(dt):
  214. """
  215. Data dtypes.
  216. Parameters
  217. ----------
  218. dt : string
  219. It could be 'uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16',
  220. 'int32', 'int64', 'float16', 'float32', 'float64', 'DType'.
  221. Returns
  222. -------
  223. Data dtypes
  224. """
  225. if dt not in _dtypeDict.keys():
  226. raise Exception("Unsupported dtype: {}".format(dt))
  227. return _dtypeDict[dt]
  228. class Maximum(object):
  229. def __init__(self):
  230. pass
  231. def __call__(self, x, y):
  232. return tf.maximum(x=x, y=y)
  233. class Minimum(object):
  234. def __init__(self):
  235. pass
  236. def __call__(self, x, y):
  237. return tf.minimum(x=x, y=y)
  238. def minimum(x, y):
  239. """
  240. Returns the min of x and y (i.e. x < y ? x : y) element-wise.
  241. Parameters
  242. ----------
  243. x : tensor.
  244. Must be one of the following types: bfloat16, half, float32, float64, int32, int64.
  245. y : A Tensor.
  246. Must have the same type as x.
  247. Returns
  248. -------
  249. A Tensor. Has the same type as x
  250. """
  251. outputs = tf.minimum(x=x, y=y)
  252. return outputs
  253. class FlattenReshape(object):
  254. def __init__(self):
  255. pass
  256. def __call__(self, inputs):
  257. dim = 1
  258. for d in get_tensor_shape(inputs)[1:]:
  259. dim *= d
  260. return tf.reshape(inputs, [-1, dim])
  261. class Reshape(object):
  262. def __init__(self, shape):
  263. self.shape = shape
  264. def __call__(self, tensor):
  265. return tf.reshape(tensor, self.shape)
  266. def reshape(tensor, shape):
  267. """
  268. Reshapes a tensor.
  269. Parameters
  270. ----------
  271. tensor : tensor
  272. A Tensor.
  273. shape : tensor
  274. Defines the shape of the output tensor.
  275. Returns
  276. -------
  277. A Tensor. Has the same type as tensor
  278. """
  279. return tf.reshape(tensor, shape)
  280. class Concat(object):
  281. def __init__(self, axis):
  282. super(Concat, self).__init__()
  283. self.axis = axis
  284. def __call__(self, values):
  285. return tf.concat(values=values, axis=self.axis)
  286. def concat(values, axis):
  287. """
  288. Concatenates tensors along one dimension.
  289. Parameters
  290. ----------
  291. values : list
  292. A list of Tensor objects or a single Tensor
  293. axis : int
  294. 0-D int32 Tensor. Dimension along which to concatenate
  295. Returns
  296. -------
  297. A Tensor resulting from concatenation of the input tensors.
  298. """
  299. return tf.concat(values, axis)
  300. def convert_to_tensor(value, dtype=None):
  301. """
  302. Converts the given value to a Tensor.
  303. Parameters
  304. ----------
  305. value : object
  306. An object whose type has a registered Tensor conversion function.
  307. dtype : optional
  308. Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
  309. Returns
  310. -------
  311. A Tensor based on value.
  312. """
  313. return tf.convert_to_tensor(value, dtype)
  314. def convert_to_numpy(value):
  315. return value.numpy()
  316. def sqrt(x):
  317. """
  318. Computes square root of x element-wise.
  319. Parameters
  320. ----------
  321. x : tensor
  322. Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.
  323. Returns
  324. -------
  325. A Tensor. Has the same type as x.
  326. """
  327. return tf.sqrt(x)
  328. class ReduceSum(object):
  329. def __init__(self, axis=None):
  330. self.axis = axis
  331. def __call__(self, input):
  332. return tf.reduce_sum(input, axis=self.axis)
  333. class ReduceMean(object):
  334. def __init__(self, axis):
  335. self.axis = axis
  336. def __call__(self, inputs):
  337. output = tf.reduce_mean(inputs, self.axis)
  338. return output
  339. def reduce_mean(input_tensor, axis=None):
  340. """
  341. Computes the mean of elements across dimensions of a tensor.
  342. Parameters
  343. ----------
  344. input_tensor : tensor
  345. The tensor to reduce. Should have numeric type.
  346. axis : list
  347. The dimensions to reduce. If None (the default), reduces all dimensions.
  348. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  349. name : str
  350. A name for the operation (optional).
  351. Returns
  352. -------
  353. The reduced tensor.
  354. """
  355. return tf.reduce_mean(input_tensor, axis=axis)
  356. class ReduceMax(object):
  357. def __init__(self, axis):
  358. self.axis = axis
  359. def __call__(self, inputs):
  360. output = tf.reduce_max(inputs, self.axis)
  361. return output
  362. def reduce_max(input_tensor, axis=None):
  363. """
  364. Computes the maximum of elements across dimensions of a tensor.
  365. Parameters
  366. ----------
  367. input_tensor : tensor
  368. The tensor to reduce. Should have real numeric type.
  369. axis : int
  370. The dimensions to reduce. If None (the default), reduces all dimensions.
  371. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  372. name : str
  373. A name for the operation (optional).
  374. Returns
  375. -------
  376. The reduced tensor.
  377. """
  378. return tf.reduce_max(input_tensor, axis=axis)
  379. def reduce_min(input_tensor, axis=None):
  380. """
  381. Computes the minimum of elements across dimensions of a tensor.
  382. Parameters
  383. ----------
  384. input_tensor : tensor
  385. The tensor to reduce. Should have real numeric type.
  386. axis : int
  387. The dimensions to reduce. If None (the default), reduces all dimensions.
  388. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  389. name : str
  390. A name for the operation (optional).
  391. Returns
  392. -------
  393. The reduced tensor.
  394. """
  395. return tf.reduce_min(input_tensor, axis=axis)
  396. class Pad(object):
  397. def __init__(self, paddings, mode="REFLECT", constant_values=0):
  398. if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
  399. raise Exception("Unsupported mode: {}".format(mode))
  400. self.paddings = paddings
  401. self.mode = mode
  402. self.constant_values = constant_values
  403. def __call__(self, x):
  404. outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=self.constant_values)
  405. return outputs
  406. def pad(tensor, paddings, mode='CONSTANT', constant_values=0):
  407. """
  408. Pads a tensor.
  409. Parameters
  410. ----------
  411. tensor : tensor
  412. A Tensor.
  413. paddings : tensor
  414. A Tensor of type int32.
  415. mode : str
  416. One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
  417. constant_values : int
  418. In "CONSTANT" mode, the scalar pad value to use. Must be same type as tensor.
  419. Returns
  420. -------
  421. A Tensor. Has the same type as tensor.
  422. """
  423. if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
  424. raise Exception("Unsupported mode: {}".format(mode))
  425. outputs = tf.pad(tensor, paddings, mode=mode, constant_values=constant_values)
  426. return outputs
  427. class Unstack(object):
  428. def __init__(self, axis, num=None):
  429. self.axis = axis
  430. self.num = num
  431. def __call__(self, values):
  432. return tf.unstack(values, num=self.num, axis=self.axis)
  433. class Stack(object):
  434. def __init__(self, axis=0):
  435. self.axis = axis
  436. def __call__(self, values):
  437. return tf.stack(values, axis=self.axis)
  438. def stack(values, axis=0):
  439. """
  440. Stacks a list of rank-R tensors into one rank-(R+1) tensor.
  441. Parameters
  442. ----------
  443. values : list
  444. A list of Tensor objects with the same shape and type.
  445. axis : int
  446. An int. The axis to stack along. Defaults to the first dimension.
  447. Negative values wrap around, so the valid range is [-(R+1), R+1).
  448. Returns
  449. -------
  450. A stacked Tensor with the same type as values.
  451. """
  452. return tf.stack(values, axis=axis)
  453. class Meshgrid(object):
  454. def __init__(self, indexing='xy'):
  455. super(Meshgrid, self).__init__()
  456. self.index = indexing
  457. def __call__(self, inputs):
  458. return tf.meshgrid(inputs)
  459. def meshgrid(*args, **kwargs):
  460. """
  461. Broadcasts parameters for evaluation on an N-D grid.
  462. Parameters
  463. ----------
  464. x : tensor
  465. Tensors with rank 1.
  466. y : tensor
  467. Tensors with rank 1.
  468. Returns
  469. -------
  470. A list of N Tensors with rank N.
  471. """
  472. return tf.meshgrid(*args, **kwargs)
  473. def range(start, limit=None, delta=1, dtype=None):
  474. """
  475. Creates a sequence of numbers.
  476. Parameters
  477. ----------
  478. start : tensor
  479. A 0-D Tensor (scalar). Acts as first entry in the range if limit is not None;
  480. otherwise, acts as range limit and first entry defaults to 0.
  481. limit : tensor
  482. A 0-D Tensor (scalar). Upper limit of sequence, exclusive. If None,
  483. defaults to the value of start while the first entry of the range defaults to 0.
  484. delta : tensor
  485. A 0-D Tensor (scalar). Number that increments start. Defaults to 1.
  486. dtype : type
  487. The type of the elements of the resulting tensor.
  488. Returns
  489. -------
  490. An 1-D Tensor of type dtype.
  491. """
  492. if limit is None:
  493. outputs = tf.range(start, delta=delta, dtype=dtype)
  494. else:
  495. outputs = tf.range(start, limit, delta=delta, dtype=dtype)
  496. return outputs
  497. class ExpandDims(object):
  498. def __init__(self, axis):
  499. self.axis = axis
  500. def __call__(self, input):
  501. return tf.expand_dims(input, axis=self.axis)
  502. def expand_dims(input, axis):
  503. """
  504. Inserts a dimension of 1 into a tensor's shape.
  505. Parameters
  506. ----------
  507. input : tensor
  508. A Tensor.
  509. axis : int
  510. 0-D (scalar). Specifies the dimension index at which to expand the shape of input.
  511. Must be in the range [-rank(input) - 1, rank(input)].
  512. Returns
  513. -------
  514. A Tensor with the same data as input, but its shape has an additional dimension of size 1 added.
  515. """
  516. return tf.expand_dims(input, axis)
  517. class Tile(object):
  518. def __init__(self):
  519. pass
  520. def __call__(self, input, multiples):
  521. return tf.tile(input, multiples)
  522. def tile(input, multiples):
  523. """
  524. Constructs a tensor by tiling a given tensor.
  525. Parameters
  526. ----------
  527. input : tensor
  528. A Tensor. 1-D or higher.
  529. multiples : tensor
  530. Must be one of the following types: int32, int64. 1-D.
  531. Length must be the same as the number of dimensions in input
  532. Returns
  533. -------
  534. A Tensor. Has the same type as input.
  535. """
  536. return tf.tile(input, multiples)
  537. class Cast(object):
  538. def __init__(self, dtype):
  539. self.dtype = dtype
  540. def __call__(self, x):
  541. return tf.cast(x, dtype=self.dtype)
  542. def cast(x, dtype):
  543. """
  544. Casts a tensor to a new type.
  545. Parameters
  546. ----------
  547. x : tensor
  548. A Tensor or SparseTensor or IndexedSlices of numeric type.
  549. It could be uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64.
  550. dtype : dtpye
  551. The destination type. The list of supported dtypes is the same as x
  552. Returns
  553. -------
  554. A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype.
  555. """
  556. return tf.cast(x, dtype=dtype)
  557. class Transpose(object):
  558. def __init__(self, perm, conjugate=False):
  559. self.perm = perm
  560. self.conjugate = conjugate
  561. def __call__(self, a):
  562. return tf.transpose(a, self.perm, self.conjugate)
  563. def transpose(a, perm=None, conjugate=False):
  564. """
  565. Transposes a.
  566. Parameters
  567. ----------
  568. a : tensor
  569. A Tensor.
  570. perm : list / int
  571. A permutation of the dimensions of a.
  572. conjugate : bool
  573. Setting it to True is mathematically equivalent to tf.math.conj(tf.transpose(input)).
  574. Returns
  575. -------
  576. A transposed Tensor.
  577. """
  578. return tf.transpose(a, perm, conjugate)
  579. def gather_nd(params, indices, batch_dims=0):
  580. """
  581. Gather slices from params into a Tensor with shape specified by indices.
  582. Parameters
  583. ----------
  584. params : tensor
  585. The tensor from which to gather values.
  586. indices : tensor
  587. Must be one of the following types: int32, int64. Index tensor.
  588. batch_dims : int
  589. An integer or a scalar 'Tensor'. The number of batch dimensions.
  590. Returns
  591. -------
  592. A Tensor. Has the same type as params.
  593. """
  594. return tf.gather_nd(params, indices, batch_dims)
  595. def clip_by_value(t, clip_value_min, clip_value_max):
  596. """
  597. Clips tensor values to a specified min and max.
  598. Parameters
  599. ----------
  600. t : tensor
  601. A Tensor or IndexedSlices
  602. clip_value_min : tensor
  603. A 0-D (scalar) Tensor, or a Tensor with the same shape as t. The minimum value to clip by
  604. clip_value_max : tensor
  605. A 0-D (scalar) Tensor, or a Tensor with the same shape as t. The minimum value to clip by
  606. Returns
  607. -------
  608. A clipped Tensor or IndexedSlices.
  609. """
  610. return tf.clip_by_value(t, clip_value_min, clip_value_max)
  611. def split(value, num_or_size_splits, axis=0, num=None):
  612. """
  613. Splits a tensor into sub tensors.
  614. Parameters
  615. ----------
  616. value : tensor
  617. The Tensor to split.
  618. num_or_size_splits : list
  619. Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or
  620. Python list containing the sizes of each output tensor along split_dim.
  621. axis : int
  622. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0.
  623. num : int
  624. used to specify the number of outputs when it cannot be inferred from the shape of size_splits.
  625. Returns
  626. -------
  627. Tensor objects resulting from splitting value.
  628. """
  629. return tf.split(value=value, num_or_size_splits=num_or_size_splits, axis=axis, num=num)
  630. class Floor(object):
  631. def __call__(self, x):
  632. return tf.floor(x)
  633. def floor(x):
  634. return tf.floor(x)
  635. def gather(params, indices):
  636. return tf.gather(params, indices)
  637. def linspace(start, stop, num):
  638. return tf.linspace(start, stop, num)
  639. def slice(inputs, starts, sizes):
  640. return tf.slice(inputs, starts, sizes)
  641. def add_n(inputs):
  642. return tf.add_n(inputs)
  643. class OneHot(object):
  644. def __init__(self, depth, on_value, off_value, axis, dtype):
  645. self.depth = depth
  646. self.on_value = on_value
  647. self.off_value = off_value
  648. self.axis = axis
  649. self.dtype = dtype
  650. def __call__(self, inputs):
  651. outputs = tf.one_hot(
  652. inputs, self.depth, on_value=self.on_value, off_value=self.off_value, axis=self.axis, dtype=self.dtype
  653. )
  654. return outputs
  655. class L2Normalize(object):
  656. def __init__(self, axis=None, epsilon=1e-12):
  657. self.axis = axis
  658. self.epsilon = epsilon
  659. def __call__(self, input, *args, **kwargs):
  660. outputs = tf.math.l2_normalize(input, axis=self.axis, epsilon=self.epsilon)
  661. return outputs
  662. class EmbeddingLookup(object):
  663. def __init__(self, max_norm=None):
  664. self.max_norm = max_norm
  665. def __call__(self, params, ids):
  666. outputs = tf.nn.embedding_lookup(params=params, ids=ids, max_norm=self.max_norm)
  667. return outputs
  668. class NCELoss(object):
  669. def __init__(self, num_true=1, sampled_values=None, remove_accidental_hits=False):
  670. self.num_true = num_true
  671. self.sampled_values = sampled_values
  672. self.remove_accidental_hits = remove_accidental_hits
  673. def __call__(self, weights, biases, labels, inputs, num_sampled, num_classes):
  674. outputs = tf.nn.nce_loss(
  675. weights=weights, biases=biases, inputs=inputs, labels=labels, num_sampled=num_sampled,
  676. num_classes=num_classes
  677. )
  678. return outputs
  679. class NotEqual(object):
  680. def __init__(self):
  681. pass
  682. def __call__(self, x, y):
  683. return tf.not_equal(x, y)
  684. class CountNonzero(object):
  685. def __init__(self, keepdims=None, dtype=int64):
  686. self.keepdims = keepdims
  687. self.dtype = dtype
  688. def __call__(self, input, axis=None):
  689. return tf.math.count_nonzero(input, axis=axis, keepdims=self.keepdims, dtype=self.dtype)
  690. class Resize:
  691. def __init__(self, scale, method, antialias=False, data_format='channels_last', ksize=None):
  692. self.method = method
  693. self.antialias = antialias
  694. self.scale = scale
  695. self.data_format = data_format
  696. def __call__(self, inputs):
  697. if self.data_format == 'channels_first':
  698. inputs = nchw_to_nhwc(inputs)
  699. if len(get_tensor_shape(inputs)) == 4:
  700. output_size = [int(inputs.shape[1] * self.scale[0]), int(inputs.shape[2] * self.scale[1])]
  701. else:
  702. raise ("The inputs shape must be 4-D Tensor.")
  703. outputs = tf.image.resize(inputs, size=output_size, method=self.method, antialias=self.antialias)
  704. if self.data_format == 'channels_first':
  705. outputs = nhwc_to_nchw(outputs)
  706. return outputs
  707. def resize(inputs, output_size, method, antialias):
  708. return tf.image.resize(inputs, size=output_size, method=method, antialias=antialias)
  709. class ZeroPadding1D(object):
  710. def __init__(self, padding):
  711. self.zeropad = tf.keras.layers.ZeroPadding1D(padding=padding)
  712. def __call__(self, inputs):
  713. return self.zeropad(inputs)
  714. class ZeroPadding2D(object):
  715. def __init__(self, padding):
  716. self.zeropad = tf.keras.layers.ZeroPadding2D(padding=padding)
  717. def __call__(self, inputs):
  718. return self.zeropad(inputs)
  719. class ZeroPadding3D(object):
  720. def __init__(self, padding):
  721. self.zeropad = tf.keras.layers.ZeroPadding3D(padding=padding)
  722. def __call__(self, inputs):
  723. return self.zeropad(inputs)
  724. class Sign(object):
  725. def __init__(self):
  726. pass
  727. def __call__(self, x):
  728. return tf.sign(x)
  729. class Ceil(object):
  730. def __call__(self, x):
  731. return tf.math.ceil(x)
  732. def ceil(x):
  733. return tf.math.ceil(x)
  734. def multiply(x, y):
  735. return tf.multiply(x, y)
  736. def divide(x, y):
  737. return tf.divide(x, y)
  738. def identity(x):
  739. return tf.identity(x)
  740. class BatchToSpace(object):
  741. def __init__(self, block_size, crops):
  742. self.bolock_size = block_size
  743. self.crops = crops
  744. def __call__(self, input_x):
  745. return tf.batch_to_space(input=input_x, block_shape=self.bolock_size, crops=self.crops)
  746. class DepthToSpace(object):
  747. def __init__(self, block_size, data_format='NHWC'):
  748. self.block_size = block_size
  749. self.data_format = data_format
  750. def __call__(self, input):
  751. return tf.nn.depth_to_space(input, block_size=self.block_size, data_format=self.data_format)

TensorLayer3.0 是一款以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。