You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Variable.py 3.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .. import ndarray
  5. from .. import stream
  def Variable(name, value=None, initializer=None, trainable=True, dtype=np.float32, ctx=None):
      """Create a variable node and return it.

      A trainable variable plays the role of a model parameter; a
      non-trainable one plays the role of a constant.

      All arguments are forwarded unchanged to ``placeholder_op``. Note that
      when neither ``value`` nor ``initializer`` is given, the resulting node
      is forced to be non-trainable by ``PlaceholderOp.__init__``.
      """
      return placeholder_op(name, value=value, initializer=initializer,
                            trainable=trainable, dtype=dtype, ctx=ctx)
  class PlaceholderOp(Op):
      """Leaf graph node holding a variable or a run-time placeholder.

      A node built with ``value`` or ``initializer`` has a known shape at
      construction time. A node built with neither has no shape here; per the
      assertion messages below, its values/shape are expected to come from
      ``feed_dict``/``feed_shape`` at run time.
      """

      def __init__(self, name, value=None, initializer=None, trainable=True, dtype=np.float32, ctx=None):
          """Create the node.

          Parameters
          ----------
          name : node name.
          value : initial value, ``numpy.ndarray`` or ``ndarray.NDArray``;
              mutually exclusive with ``initializer``.
          initializer : object exposing ``.shape``; it is later called as
              ``initializer(node, seed, np_rand, comp_stream)`` in
              ``backward_hook``.
          trainable : whether the node is trainable. It is forced to False
              when neither ``value`` nor ``initializer`` is given.
          dtype : element type recorded on the node. It is not applied to
              ``value`` here.
          ctx : device context. If None, it is filled in from
              ``config.context`` in ``backward_hook``.

          Raises
          ------
          AssertionError
              If both ``value`` and ``initializer`` are given, or if
              ``value`` has an unsupported type.
          """
          super().__init__(PlaceholderOp, [], ctx)
          self.name = name
          # Never set to True in this class; presumably flipped by callers
          # for embedding tables (it affects the PS/Hybrid choice below).
          self.is_embed = False
          self.shape = None
          if value is None and initializer is None:
              # Pure placeholder: nothing to initialize, so nothing to train.
              trainable = False
          elif value is not None:
              assert initializer is None, 'Value already specified, initializer should be None.'
              assert isinstance(value, (np.ndarray, ndarray.NDArray)),\
                  'Value data type %s not valid.' % str(type(value))
              self.shape = value.shape
          else:
              # Only the initializer was supplied; it carries the target shape.
              self.shape = initializer.shape
          self.tensor_value = value
          self.initializer = initializer
          self.trainable = trainable
          self.dtype = dtype

      def compute(self, input_vals, output_val, stream_handle=None):
          """No-op; only checks that the node has a known shape.

          ``is not None`` is used rather than truthiness, so that a scalar
          (0-d) value with shape ``()`` is accepted.
          """
          assert self.shape is not None, "placeholder %s values provided by feed_dict" % self.name

      def gradient(self, output_grad):
          """Return None: this node has no inputs to propagate a gradient to."""
          return None

      def infer_shape(self, input_shapes):
          """Return the shape fixed at construction time.

          ``is not None`` is used rather than truthiness, so that a 0-d value
          (shape ``()``) reports its shape instead of failing the assertion.
          """
          assert self.shape is not None, "placeholder %s shape provided by feed_shape" % self.name
          return self.shape

      def forward_hook(self, config):
          """Do nothing; all setup happens in ``backward_hook``."""
          pass

      def backward_hook(self, config):
          """Bind the node to a device and materialize its initial value.

          - Trainable nodes under the 'PS' strategy, and trainable embedding
            nodes under the 'Hybrid' strategy, are moved to ``cpu(0)`` and
            given a ``CSEvent`` or ``PSEvent``.
          - Otherwise the initializer, if any, is run once and then cleared.
          - Failing that, an existing ``tensor_value`` is converted or copied
            to an ``ndarray`` on ``self.ctx``.

          Finally ``on_gpu``/``on_cpu`` are set from ``self.ctx``.
          """
          if self.ctx is None:
              self.ctx = config.context
          # Query the strategy once instead of once per comparison.
          strategy = config.node_strategy.get(self, config.comm_mode)
          uses_ps = strategy == 'PS' or (strategy == 'Hybrid' and self.is_embed)
          if uses_ps and self.trainable:
              # Presumably parameter-server managed, hence kept on the host.
              self.ctx = ndarray.cpu(0)
              if config.cstable_policy is not None and self.is_embed:
                  self.event = stream.CSEvent(config.ps_comm, self.id)
              else:
                  self.event = stream.PSEvent(config.ps_comm, self.id)
          elif self.initializer is not None:
              self.initializer(self, config.seed,
                               config.np_rand, config.comp_stream)
              # Clear the initializer so it runs at most once.
              self.initializer = None
          elif self.tensor_value is not None:
              value = self.tensor_value
              assert isinstance(value, (np.ndarray, ndarray.NDArray)), \
                  'Parameters should be initialized as numpy.ndarray or ndarray.NDArray .'
              if isinstance(value, np.ndarray):
                  value = ndarray.array(value, self.ctx)
              elif value.ctx != self.ctx:
                  # Already an NDArray, but on another device: copy it over.
                  new_value = ndarray.empty(value.shape, self.ctx)
                  value.copyto(new_value)
                  value = new_value
              self.tensor_value = value
          self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
          self.on_cpu = not self.on_gpu
  def placeholder_op(name, value=None, initializer=None, trainable=True, dtype=np.float32, ctx=None):
      """Build a variable/placeholder graph node.

      Parameters
      ----------
      name : node name.
      value : initial value (``numpy.ndarray`` or ``ndarray.NDArray``), or None.
      initializer : initializer object, or None; mutually exclusive with ``value``.
      trainable : whether the node should be trained.
      dtype : element type recorded on the node.
      ctx : device context, or None.

      Returns
      -------
      PlaceholderOp
          A freshly constructed node.
      """
      node = PlaceholderOp(name, value=value, initializer=initializer,
                           trainable=trainable, dtype=dtype, ctx=ctx)
      return node