You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. using Tensorflow.NumPy;
  2. using System;
  3. using Tensorflow.Eager;
  4. using Tensorflow.Variables;
  5. using Tensorflow.Train;
  6. using static Tensorflow.Binding;
  7. using System.Collections.Generic;
  8. using System.Diagnostics;
  9. using Tensorflow.Checkpoint;
  10. using Tensorflow.Training.Saving.SavedModel;
  11. using OneOf;
  12. using Tensorflow.Graphs;
  13. namespace Tensorflow
  14. {
  15. public class BaseResourceVariable : DisposableTrackableObject
  16. {
  17. protected string _name;
  18. public virtual string Name => _handle_name;
  19. public virtual string SharedName
  20. {
  21. get
  22. {
  23. // TODO(Rinne): optimize the implementation with refactor of variable.
  24. return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1);
  25. }
  26. }
  27. protected TF_DataType _dtype;
  28. public TF_DataType dtype => _dtype;
  29. protected string _handle_name;
  30. public string handle_name
  31. {
  32. get { return _handle_name; }
  33. set { _handle_name = value; }
  34. }
  35. protected string _unique_id;
  36. public string UniqueId => _unique_id;
  37. protected bool _in_graph_mode;
  38. internal bool InGraphMode => _in_graph_mode;
  39. protected bool _trainable;
  40. public bool Trainable => _trainable;
  41. protected Tensor _initial_value;
  42. public Operation initializer => initializer_op;
  43. protected Tensor _parent_op;
  44. public Tensor parent_op => _parent_op;
  45. /// <summary>
  46. /// Tensor handle
  47. /// </summary>
  48. protected Tensor handle;
  49. public Tensor Handle => handle;
  50. protected Tensor _graph_element;
  51. public Tensor GraphElement => _graph_element;
  52. protected Shape _shape;
  53. public Shape shape => _shape;
  54. protected Operation initializer_op;
  55. public Operation Initializer => initializer_op;
  56. public Operation Op => handle.op;
  57. public Graph Graph => handle.graph;
  58. public string Device => handle.Device;
  59. EagerResourceDeleter eager_resource_deleter;
  60. public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;
  61. public BaseResourceVariable()
  62. {
  63. }
  64. public void __init__(bool trainable = true,
  65. Shape shape = null,
  66. TF_DataType dtype = TF_DataType.DtInvalid,
  67. Tensor handle = null,
  68. string name = null,
  69. string unique_id = null,
  70. string handle_name = null)
  71. {
  72. _trainable = trainable;
  73. _handle_name = handle_name + ":0";
  74. _unique_id = unique_id;
  75. this.handle = handle;
  76. _name = name;
  77. if(shape is not null)
  78. {
  79. _shape = shape;
  80. }
  81. if(dtype != TF_DataType.DtInvalid)
  82. {
  83. _dtype = dtype;
  84. }
  85. // After the handle has been created, set up a way to clean it up when
  86. // executing eagerly. We'll hold the only reference to the deleter, so that
  87. // when this object is garbage collected the deleter will be too. This
  88. // means ResourceVariables can be part of reference cycles without those
  89. // cycles being uncollectable.
  90. if (handle is EagerTensor)
  91. {
  92. _handle = handle.EagerTensorHandle.DangerousGetHandle();
  93. // eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
  94. }
  95. else if(handle is null)
  96. {
  97. // TODO: fix this dangerous change.
  98. _handle = IntPtr.Zero;
  99. }
  100. else
  101. {
  102. _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
  103. }
  104. #if TRACK_TENSOR_LIFE
  105. print($"Created Resource 0x{_handle.ToString("x16")} {_name}");
  106. #endif
  107. }
  108. public Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
  109. {
  110. if (value.GetType() == typeof(Tensor))
  111. {
  112. var assign = gen_state_ops.assign(handle, value, use_locking: use_locking, name: name);
  113. if (read_value)
  114. return assign;
  115. return assign.op;
  116. }
  117. var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
  118. var assign_op = gen_resource_variable_ops.assign_variable_op(
  119. handle, value_tensor, name: name);
  120. if (read_value)
  121. return gen_resource_variable_ops.read_variable_op(handle, dtype);
  122. if (assign_op == null)
  123. return null;
  124. return assign_op;
  125. }
  126. public void StridedSliceAssign(Tensor value, ParsedSliceArgs slice)
  127. {
  128. _strided_slice_assign(slice.PackedBegin, slice.PackedEnd, slice.PackedStrides, value);
  129. }
  130. void _strided_slice_assign(Tensor begin, Tensor end, Tensor strides, Tensor value, string name = null,
  131. int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0)
  132. {
  133. var op = gen_array_ops.resource_strided_slice_assign(handle, begin, end, strides, value,
  134. begin_mask: begin_mask,
  135. end_mask: end_mask,
  136. ellipsis_mask: ellipsis_mask,
  137. new_axis_mask: new_axis_mask,
  138. shrink_axis_mask: shrink_axis_mask);
  139. }
  140. public IVariableV1 assign_lazy_load(Tensor value, string name = null)
  141. {
  142. var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
  143. var assign_op = gen_resource_variable_ops.assign_variable_op(
  144. handle, value_tensor, name: name);
  145. var variable = _lazy_read(assign_op, value_tensor);
  146. return variable;
  147. }
  148. public Tensor value()
  149. => GraphElement ?? _read_variable_op();
  150. protected Tensor _read_variable_op(bool no_copy = false)
  151. {
  152. variable_accessed(this);
  153. Tensor read_and_set_handle(bool no_copy)
  154. {
  155. if (no_copy)
  156. {
  157. gen_resource_variable_ops.disable_copy_on_read(handle);
  158. }
  159. var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
  160. resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
  161. return result;
  162. }
  163. // TODO(Rinne): deal with caching device.
  164. var result = read_and_set_handle(no_copy);
  165. if (!tf.Context.executing_eagerly())
  166. {
  167. tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
  168. backward_function: (x, _) => x);
  169. }
  170. // have to set shape when converting to substituent placeholder
  171. if (result.shape.ndim == -1)
  172. {
  173. c_api.TF_GraphSetTensorShape(result.graph,
  174. result._as_tf_output(),
  175. shape.dims,
  176. shape.ndim,
  177. tf.Status);
  178. tf.Status.Check(true);
  179. }
  180. return result;
  181. }
  182. IVariableV1 _lazy_read(Operation op, Tensor value)
  183. {
  184. variable_accessed(this);
  185. return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
  186. }
  187. /// <summary>
  188. /// Records that `variable` was accessed for the tape and FuncGraph.
  189. /// </summary>
  190. void variable_accessed(BaseResourceVariable variable)
  191. {
  192. if(ops.get_default_graph() is FuncGraph func_graph)
  193. {
  194. func_graph.watch_variable(variable as IVariableV1);
  195. }
  196. if (variable.Trainable)
  197. {
  198. foreach (var tape in tf.GetTapeSet())
  199. tape.VariableAccessed(variable as ResourceVariable);
  200. }
  201. }
  202. /// <summary>
  203. /// Constructs an op which reads the value of this variable.
  204. ///
  205. /// Should be used when there are multiple reads, or when it is desirable to
  206. /// read the value only after some condition is true.
  207. /// </summary>
  208. /// <returns></returns>
  209. protected Tensor read_value()
  210. {
  211. var value = tf_with(ops.name_scope("Read"), delegate
  212. {
  213. return _read_variable_op();
  214. });
  215. return array_ops.identity(value);
  216. }
  217. public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
  218. {
  219. var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
  220. ops.convert_to_tensor(delta, dtype: dtype), name: name);
  221. if (read_value)
  222. return gen_resource_variable_ops.read_variable_op(handle, dtype);
  223. // return _lazy_read(assign_add_op);
  224. return assign_add_op;
  225. }
  226. public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
  227. {
  228. var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
  229. ops.convert_to_tensor(delta, dtype: dtype), name: name);
  230. if (read_value)
  231. return gen_resource_variable_ops.read_variable_op(handle, dtype);
  232. // return _lazy_read(assign_add_op);
  233. return assign_sub_op;
  234. }
  235. public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
  236. {
  237. var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
  238. ops.convert_to_tensor(delta, dtype: dtype), name: name);
  239. return _lazy_read(assign_sub_op, delta);
  240. }
  241. public override string ToString()
  242. {
  243. if (tf.Context.executing_eagerly())
  244. return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={read_value().numpy()}";
  245. else
  246. return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
  247. }
  248. public NDArray numpy() => read_value().numpy();
  249. protected override void DisposeUnmanagedResources(IntPtr handle)
  250. {
  251. #if TRACK_TENSOR_LIFE
  252. print($"Deleted Resource 0x{handle.ToString("x16")} {_name}");
  253. #endif
  254. }
  255. public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
  256. {
  257. if (as_ref)
  258. return read_value().op.inputs[0];
  259. else
  260. return value();
  261. }
  262. public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options)
  263. {
  264. BaseResourceVariable new_variable;
  265. if (save_options.experimental_variable_policy.save_variable_devices())
  266. {
  267. Debug.Assert(this is ResourceVariable);
  268. new_variable = tf_with(ops.device(this.Device), _ =>
  269. {
  270. return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
  271. });
  272. }
  273. else
  274. {
  275. new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
  276. }
  277. Dictionary<Trackable, Trackable> obj_map = new();
  278. Dictionary<Tensor, Tensor> resource_map = new();
  279. obj_map[this] = new_variable;
  280. resource_map[this.handle] = new_variable.handle;
  281. return (obj_map, resource_map);
  282. }
  283. /// <summary>
  284. /// Writes additional information of the variable into the SavedObject proto.
  285. /// ubclasses of ResourceVariables could choose to override this method to
  286. /// customize extra information to provide when saving a SavedModel.
  287. /// </summary>
  288. /// <param name="proto"></param>
  289. /// <param name="options"></param>
  290. public virtual void write_object_proto(SavedObject proto, SaveOptions options)
  291. {
  292. resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
  293. }
  294. public override IDictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
  295. {
  296. var res = new Dictionary<string, Func<string, OneOf<BaseResourceVariable, MySaveableObject>>>();
  297. res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
  298. return res;
  299. }
  300. public Tensor is_initialized(string name = null)
  301. {
  302. return gen_resource_variable_ops.var_is_initialized_op(this.handle, name);
  303. }
  304. public Tensor read_value_no_copy()
  305. {
  306. Tensor value = null;
  307. tf_with(ops.name_scope("Read"), _ =>
  308. {
  309. // TODO: `no_copy = true`.
  310. value = _read_variable_op();
  311. });
  312. return array_ops.identity(value);
  313. }
  314. //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y;
  315. //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y;
  316. //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y;
  317. //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value();
  318. //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y;
  319. //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y;
  320. //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y;
  321. //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y;
  322. //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value();
  323. //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value();
  324. //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y;
  325. //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y;
  326. //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y;
  327. //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y;
  328. }
  329. }