You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseResourceVariable.cs 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow.Eager;
  6. using Tensorflow.Gradients;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow
  9. {
  /// <summary>
  /// Common base for variables backed by a resource handle tensor. Holds the handle and
  /// variable metadata, and implements reads, <c>assign</c>/<c>assign_add</c> and string
  /// formatting on top of <c>gen_resource_variable_ops</c>.
  /// </summary>
  public class BaseResourceVariable : DisposableObject, IVariableV1
  {
      /// <summary>Name passed to <c>__init__</c>; note that <see cref="Name"/> does not return it.</summary>
      protected string _name;

      /// <summary>
      /// Name reported by this variable: the <c>handle_name</c> passed to <c>__init__</c>
      /// with ":0" appended.
      /// </summary>
      public virtual string Name => _handle_name;

      // Not assigned in this class; presumably set by derived classes — TODO confirm.
      protected TF_DataType _dtype;
      public TF_DataType dtype => _dtype;

      protected string _handle_name;
      protected string handle_name => _handle_name;

      protected string _unique_id;
      public string unique_id => _unique_id;

      // Not assigned in this class; only read by _lazy_read.
      protected bool _in_graph_mode;

      // Only trainable variables are reported to the active gradient tapes (see variable_accessed).
      protected bool _trainable;
      public bool trainable => _trainable;

      protected Tensor _initial_value;
      public Tensor initial_value => _initial_value;

      protected Tensor _parent_op;
      public Tensor parent_op => _parent_op;

      /// <summary>
      /// Resource handle tensor used by every read and write of this variable.
      /// </summary>
      protected Tensor handle;
      public Tensor Handle => handle;

      // Not assigned in this class (null unless a derived class sets it).
      protected Tensor _graph_element;
      public Tensor GraphElement => _graph_element;

      protected TensorShape _shape;
      public TensorShape shape => _shape;

      protected Operation initializer_op;
      public Operation Initializer => initializer_op;

      /// <summary>Shortcut for <c>Handle.op</c>.</summary>
      public Operation Op => handle.op;

      /// <summary>Shortcut for <c>Handle.graph</c>.</summary>
      public Graph Graph => handle.graph;

      /// <summary>Creates an instance with every field left at its default value.</summary>
      public BaseResourceVariable()
      {
      }

      /// <summary>
      /// Stores <paramref name="handle"/> in <c>_handle</c> (presumably the native handle
      /// field inherited from <c>DisposableObject</c>, which is not visible here).
      /// </summary>
      /// <param name="handle">Native handle to store.</param>
      /// <param name="tensor">Currently unused.</param>
      public BaseResourceVariable(IntPtr handle, IntPtr tensor)
      {
          _handle = handle;
      }

      /// <summary>
      /// Populates the core fields of the variable. <c>_dtype</c>, <c>_shape</c> and the
      /// remaining metadata fields are not set here.
      /// </summary>
      /// <param name="trainable">Whether reads are reported to the active gradient tapes.</param>
      /// <param name="handle">Resource handle tensor used for all reads and writes.</param>
      /// <param name="name">Stored in <c>_name</c>.</param>
      /// <param name="unique_id">Stored in <c>_unique_id</c>.</param>
      /// <param name="handle_name">
      /// Base name; ":0" is appended to form <see cref="Name"/>.
      /// NOTE(review): a null value yields the name ":0".
      /// </param>
      public void __init__(bool trainable = true,
          Tensor handle = null,
          string name = null,
          string unique_id = null,
          string handle_name = null)
      {
          _trainable = trainable;
          _handle_name = handle_name + ":0";
          _unique_id = unique_id;
          this.handle = handle;
          _name = name;
          // handle_deleter: not set up here.
      }

      /// <summary>
      /// Writes <paramref name="value"/> into the variable.
      /// </summary>
      /// <param name="value">New value; converted to a tensor of this variable's <see cref="dtype"/>.</param>
      /// <param name="use_locking">Ignored.</param>
      /// <param name="name">Name for the assign op.</param>
      /// <param name="read_value">When true, return a fresh read of the variable instead of the assign op.</param>
      /// <returns>A read tensor when <paramref name="read_value"/> is true; otherwise the assign op.</returns>
      public ITensorOrOperation assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
      {
          var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
          var assign_op = gen_resource_variable_ops.assign_variable_op(
              handle, value_tensor, name: name);

          // NOTE(review): this read carries no control dependency on assign_op, so outside
          // eager execution it may not be ordered after the write — confirm.
          if (read_value)
          {
              return gen_resource_variable_ops.read_variable_op(handle, dtype);
          }

          return assign_op;
      }

      /// <summary>Reads the current value of the variable (records the access for gradient tapes).</summary>
      public Tensor value() => _read_variable_op();

      /// <summary>
      /// Records the access with <c>variable_accessed</c>, then emits a read of the handle
      /// with this variable's dtype.
      /// </summary>
      protected Tensor _read_variable_op()
      {
          variable_accessed(this);
          var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
          // _maybe_set_handle_data(_dtype, _handle, result) is not implemented.
          return result;
      }

      /// <summary>
      /// Records the access and wraps the handle in an <c>_UnreadVariable</c>.
      /// Not called within this class; <paramref name="op"/> and <paramref name="value"/>
      /// are currently ignored.
      /// </summary>
      BaseResourceVariable _lazy_read(Operation op, Tensor value)
      {
          variable_accessed(this);
          return new _UnreadVariable(handle, _dtype, _shape, _in_graph_mode, _unique_id);
      }

      /// <summary>
      /// Records that <paramref name="variable"/> was accessed on every active gradient tape,
      /// provided it is trainable.
      /// </summary>
      void variable_accessed(BaseResourceVariable variable)
      {
          if (!variable.trainable)
          {
              return;
          }

          // The tape API takes a ResourceVariable. The previous `as` cast handed the tapes a
          // null reference for any other BaseResourceVariable subclass; skip those instead.
          if (!(variable is ResourceVariable resource))
          {
              return;
          }

          foreach (var tape in tf.GetTapeSet())
          {
              tape.VariableAccessed(resource);
          }
      }

      /// <summary>
      /// Constructs an op which reads the value of this variable: a read inside a "Read"
      /// name scope, followed by an identity op.
      ///
      /// Should be used when there are multiple reads, or when it is desirable to
      /// read the value only after some condition is true.
      /// </summary>
      /// <returns>The identity tensor of the read value.</returns>
      Tensor read_value()
          => tf_with(ops.name_scope("Read"), delegate
          {
              var value = _read_variable_op();
              return array_ops.identity(value);
          });

      /// <summary>
      /// Adds <paramref name="delta"/> to the variable.
      /// </summary>
      /// <param name="delta">Amount to add; converted to a tensor of this variable's <see cref="dtype"/>.</param>
      /// <param name="use_locking">Ignored.</param>
      /// <param name="name">Name for the assign-add op.</param>
      /// <param name="read_value">When true, return a fresh read of the variable instead of the assign-add op.</param>
      /// <returns>A read tensor when <paramref name="read_value"/> is true; otherwise the assign-add op.</returns>
      public ITensorOrOperation assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
      {
          var delta_tensor = ops.convert_to_tensor(delta, dtype: dtype);
          var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(handle,
              delta_tensor, name: name);

          // NOTE(review): as in assign, this read has no control dependency on
          // assign_add_op outside eager execution — confirm ordering.
          if (read_value)
          {
              return gen_resource_variable_ops.read_variable_op(handle, dtype);
          }

          return assign_add_op;
      }

      /// <summary>
      /// Formats name, shape and dtype. When executing eagerly, also appends the current value
      /// (which performs a read of the variable).
      /// </summary>
      public override string ToString()
      {
          var eager = tf.context.executing_eagerly();
          var description = $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";

          if (eager)
          {
              return $"{description}, numpy={tensor_util.to_numpy_string(read_value())}";
          }

          return description;
      }

      /// <summary>Reads the current value and converts it to an <c>NDArray</c>.</summary>
      public NDArray numpy() => read_value().numpy();

      /// <summary>
      /// No-op: this class releases nothing here.
      /// NOTE(review): confirm whether the handle tensor should be released on dispose.
      /// </summary>
      protected override void DisposeUnmanagedResources(IntPtr handle)
      {
      }

      /// <summary>Returns <c>_graph_element</c> (never assigned by this class itself).</summary>
      public Tensor AsTensor() => _graph_element;
  }
  129. }