You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 8.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. using NumSharp;
  2. using System;
  3. using static Tensorflow.Binding;
  namespace Tensorflow
  {
      /// <summary>
      /// Shared state and operations for variables backed by a resource handle tensor:
      /// reading the value, assigning (plain / add / sub / strided slice) and recording
      /// accesses on the active gradient tapes.
      /// </summary>
      public class BaseResourceVariable : DisposableObject
      {
          protected string _name;

          /// <summary>
          /// Variable name. This returns the handle name (which <c>__init__</c> suffixes
          /// with ":0"), not <c>_name</c>.
          /// </summary>
          public virtual string Name => _handle_name;

          protected TF_DataType _dtype;
          public TF_DataType dtype => _dtype;

          protected string _handle_name;
          protected string handle_name => _handle_name;

          protected string _unique_id;
          public string UniqueId => _unique_id;

          protected bool _in_graph_mode;

          protected bool _trainable;
          public bool trainable => _trainable;

          protected Tensor _initial_value;
          public Tensor initial_value => _initial_value;

          public Operation initializer => initializer_op;

          protected Tensor _parent_op;
          public Tensor parent_op => _parent_op;

          /// <summary>
          /// Tensor handle
          /// </summary>
          protected Tensor handle;
          public Tensor Handle => handle;

          protected Tensor _graph_element;
          public Tensor GraphElement => _graph_element;

          protected TensorShape _shape;
          public TensorShape shape => _shape;

          protected Operation initializer_op;
          public Operation Initializer => initializer_op;

          public Operation Op => handle.op;
          public Graph Graph => handle.graph;
          public string Device => "";

          public BaseResourceVariable()
          {
          }

          /// <summary>
          /// Wraps an existing native handle. Only <paramref name="handle"/> is stored
          /// (into the inherited <c>_handle</c> field); <paramref name="tensor"/> is currently unused.
          /// </summary>
          public BaseResourceVariable(IntPtr handle, IntPtr tensor)
          {
              _handle = handle;
          }

          /// <summary>
          /// Populates the core fields of the variable after construction.
          /// </summary>
          /// <param name="trainable">When true, reads are reported to every tape in <c>tf.GetTapeSet()</c>.</param>
          /// <param name="handle">Resource handle tensor the read/assign ops operate on.</param>
          /// <param name="name">Stored in <c>_name</c>.</param>
          /// <param name="unique_id">Stored in <c>_unique_id</c>.</param>
          /// <param name="handle_name">Base name; ":0" is appended to form <see cref="Name"/>.</param>
          public void __init__(bool trainable = true,
              Tensor handle = null,
              string name = null,
              string unique_id = null,
              string handle_name = null)
          {
              _trainable = trainable;
              // NOTE(review): a null handle_name produces the name ":0" — confirm callers always pass one.
              _handle_name = handle_name + ":0";
              _unique_id = unique_id;
              this.handle = handle;
              _name = name;
              // handle_deleter
          }

          /// <summary>
          /// Assigns a new value to this variable.
          /// </summary>
          /// <typeparam name="T">Type of the value; must be something <c>ops.convert_to_tensor</c> accepts
          /// unless it is exactly <see cref="Tensor"/>.</typeparam>
          /// <param name="value">New value; must not be null.</param>
          /// <param name="use_locking">Forwarded only on the exact-<see cref="Tensor"/> path.</param>
          /// <param name="name">Optional name for the assign op.</param>
          /// <param name="read_value">When true, return a read of the variable after the assignment;
          /// otherwise return the assign operation.</param>
          /// <returns>The read value, or the assign operation returned as a <see cref="Tensor"/>.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
          public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
          {
              // Fail fast with a clear error instead of a NullReferenceException from GetType().
              if (value == null)
                  throw new ArgumentNullException(nameof(value));

              // Exact type match: Tensor subclasses are not matched here and take the convert path below.
              if (value.GetType() == typeof(Tensor))
              {
                  var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
                  if (read_value)
                      return assign;
                  return assign.op;
              }

              var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
              var assign_op = gen_resource_variable_ops.assign_variable_op(
                  handle, value_tensor, name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              return assign_op;
          }

          /// <summary>
          /// Assigns <paramref name="value"/> into the region of this variable described by <paramref name="slice"/>.
          /// </summary>
          public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
          {
              // NOTE(review): only begin/end/strides are forwarded; every mask stays at its default of 0.
              _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
          }

          /// <summary>
          /// Emits a resource strided-slice assign against <c>handle</c>.
          /// <paramref name="name"/> is accepted but currently not forwarded to the op.
          /// </summary>
          void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
              int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
          {
              gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
                  begin_mask: begin_mask,
                  end_mask: end_mask,
                  ellipsis_mask: ellipsis_mask,
                  new_axis_mask: new_axis_mask,
                  shrink_axis_mask: shrink_axis_mask);
          }

          /// <summary>
          /// Assigns <paramref name="value"/> and returns an unread variable view instead of reading the value back.
          /// </summary>
          public IVariableV1 assign_lazy_load(Tensor value, string name = null)
          {
              var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
              var assign_op = gen_resource_variable_ops.assign_variable_op(
                  handle, value_tensor, name: name);
              var variable = _lazy_read(assign_op, value_tensor);
              return variable;
          }

          /// <summary>
          /// Returns <see cref="GraphElement"/> when set, otherwise emits a fresh read of the variable.
          /// </summary>
          public Tensor value()
              => GraphElement ?? _read_variable_op();

          /// <summary>
          /// Records the access on the tapes and emits a read of the variable's current value.
          /// </summary>
          protected Tensor _read_variable_op()
          {
              variable_accessed(this);
              var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
              // _maybe_set_handle_data(_dtype, _handle, result);

              // have to set shape when converting to substituent placeholder.
              // Only possible when the variable's own shape is known; otherwise there is nothing to propagate
              // and as_list_long() on an unknown shape is not meaningful.
              if (result.TensorShape.ndim == -1 && shape.ndim != -1)
              {
                  c_api.TF_GraphSetTensorShape(result.graph,
                      result._as_tf_output(),
                      shape.as_list_long(),
                      shape.ndim,
                      tf.Status.Handle);
                  tf.Status.Check(true);
              }
              return result;
          }

          /// <summary>
          /// Records the access and returns an <c>_UnreadVariable</c> sharing this variable's handle.
          /// </summary>
          /// <remarks>
          /// NOTE(review): <paramref name="op"/> and <paramref name="value"/> are currently unused, so the
          /// returned view carries no dependency on the op that produced it — confirm this is intended.
          /// </remarks>
          IVariableV1 _lazy_read(Operation op, Tensor value)
          {
              variable_accessed(this);
              return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
          }

          /// <summary>
          /// Records that `variable` was accessed for the tape and FuncGraph.
          /// Non-trainable variables are not recorded.
          /// </summary>
          void variable_accessed(BaseResourceVariable variable)
          {
              if (variable.trainable)
              {
                  // NOTE(review): `as` yields null when `variable` is not a ResourceVariable
                  // (e.g. an _UnreadVariable); confirm VariableAccessed tolerates null.
                  foreach (var tape in tf.GetTapeSet())
                      tape.VariableAccessed(variable as ResourceVariable);
              }
          }

          /// <summary>
          /// Constructs an op which reads the value of this variable.
          ///
          /// Should be used when there are multiple reads, or when it is desirable to
          /// read the value only after some condition is true.
          /// </summary>
          /// <returns>An identity of the read, created under the "Read" name scope.</returns>
          protected Tensor read_value()
          {
              var value = tf_with(ops.name_scope("Read"), delegate
              {
                  return _read_variable_op();
              });
              return array_ops.identity(value);
          }

          /// <summary>
          /// Adds <paramref name="delta"/> to the variable. Returns a read of the new value when
          /// <paramref name="read_value"/> is true, otherwise the assign-add op.
          /// <paramref name="use_locking"/> is accepted but not forwarded.
          /// </summary>
          public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
          {
              var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              // return _lazy_read(assign_add_op);
              return assign_add_op;
          }

          /// <summary>
          /// Subtracts <paramref name="delta"/> from the variable. Returns a read of the new value when
          /// <paramref name="read_value"/> is true, otherwise the assign-sub op.
          /// <paramref name="use_locking"/> is accepted but not forwarded.
          /// </summary>
          public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
          {
              var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              if (read_value)
                  return gen_resource_variable_ops.read_variable_op(handle, dtype);
              // return _lazy_read(assign_sub_op);
              return assign_sub_op;
          }

          /// <summary>
          /// Subtracts <paramref name="delta"/> and returns an unread variable view instead of reading the value back.
          /// </summary>
          public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
          {
              var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
                  ops.convert_to_tensor(delta, dtype: dtype), name: name);
              return _lazy_read(assign_sub_op, delta);
          }

          /// <summary>
          /// Describes the variable; in eager mode the current value is appended.
          /// </summary>
          public override string ToString()
          {
              var description = $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
              return tf.Context.executing_eagerly()
                  ? $"{description}, numpy={tensor_util.to_numpy_string(read_value())}"
                  : description;
          }

          /// <summary>Reads the variable and converts the result to an <see cref="NDArray"/>.</summary>
          public NDArray numpy() => read_value().numpy();

          protected override void DisposeUnmanagedResources(IntPtr handle)
          {
          }

          /// <summary>
          /// Converts the variable to a tensor.
          /// </summary>
          /// <param name="dtype">Currently ignored.</param>
          /// <param name="name">Currently ignored.</param>
          /// <param name="as_ref">When true, returns the first input of the identity op built by
          /// <see cref="read_value"/> (presumably the underlying read-op output); otherwise <see cref="value"/>.</param>
          public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
          {
              if (as_ref)
                  return read_value().op.inputs[0];
              else
                  return value();
          }
      }
  }