You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RefVariable.cs 6.1 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. namespace Tensorflow
  6. {
  7. public partial class RefVariable : VariableV1
  8. {
  9. public bool _in_graph_mode = true;
  10. public Tensor _initial_value;
  11. public string _graph_key;
  12. public bool _trainable;
  13. public Tensor _variable;
  14. public Tensor _snapshot;
  15. private Operation _initializer_op;
  16. public Operation initializer => _initializer_op;
  17. public Operation op => _variable.op;
  18. public Graph graph => _variable.Graph;
  19. public TF_DataType dtype => _variable.dtype;
  20. public TensorShape shape => tensor_util.to_shape(_variable.shape);
  21. public string name => _variable.name;
  22. public RefVariable(object initial_value,
  23. bool trainable = true,
  24. List<string> collections = null,
  25. bool validate_shape = true,
  26. string caching_device = "",
  27. string name = "",
  28. TF_DataType dtype = TF_DataType.DtInvalid) :
  29. base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
  30. {
  31. _in_graph_mode = true;
  32. _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
  33. }
  34. private void _init_from_args(object initial_value,
  35. bool trainable = true,
  36. List<string> collections = null,
  37. bool validate_shape = true,
  38. string caching_device = "",
  39. string name = "",
  40. TF_DataType dtype = TF_DataType.DtInvalid)
  41. {
  42. if (initial_value is null)
  43. throw new ValueError("initial_value must be specified.");
  44. var init_from_fn = false;
  45. if(collections == null)
  46. {
  47. collections = new List<string> { ops.GraphKeys.GLOBAL_VARIABLES };
  48. }
  49. // Store the graph key so optimizers know how to only retrieve variables from
  50. // this graph.
  51. _graph_key = ops.get_default_graph()._graph_key;
  52. _trainable = trainable;
  53. if (!collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
  54. collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
  55. ops.init_scope();
  56. var values = init_from_fn ? new List<object>() : new List<object> { initial_value };
  57. using (var namescope = new ops.name_scope<object>(name, "Variable", values))
  58. {
  59. name = namescope;
  60. if (init_from_fn)
  61. {
  62. }
  63. // Or get the initial value from a Tensor or Python object.
  64. else
  65. {
  66. _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
  67. var shape = _initial_value.shape;
  68. dtype = _initial_value.dtype;
  69. _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
  70. }
  71. // Manually overrides the variable's shape with the initial value's.
  72. if (validate_shape)
  73. {
  74. var initial_value_shape = _initial_value.shape;
  75. }
  76. // If 'initial_value' makes use of other variables, make sure we don't
  77. // have an issue if these other variables aren't initialized first by
  78. // using their initialized_value() method.
  79. var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);
  80. _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;
  81. if (!String.IsNullOrEmpty(caching_device))
  82. {
  83. }
  84. else
  85. {
  86. ops.colocate_with(_initializer_op);
  87. _snapshot = gen_array_ops.identity(_variable, name = "read");
  88. }
  89. ops.add_to_collections(collections, this);
  90. }
  91. }
  92. public Tensor _ref()
  93. {
  94. return _variable;
  95. }
  96. public Tensor _AsTensor()
  97. {
  98. return _snapshot;
  99. }
  100. /// <summary>
  101. /// Attempt to guard against dependencies on uninitialized variables.
  102. /// </summary>
  103. /// <param name="initial_value"></param>
  104. private Tensor _try_guard_against_uninitialized_dependencies(Tensor initial_value)
  105. {
  106. return _safe_initial_value_from_tensor(initial_value, new Dictionary<string, Operation>());
  107. }
  108. /// <summary>
  109. /// Replace dependencies on variables with their initialized values.
  110. /// </summary>
  111. /// <param name="tensor">A `Tensor`. The tensor to replace.</param>
  112. /// <param name="op_cache">A dict mapping operation names to `Operation`s.</param>
  113. /// <returns>A `Tensor` compatible with `tensor`.</returns>
  114. private Tensor _safe_initial_value_from_tensor(Tensor tensor, Dictionary<string, Operation> op_cache)
  115. {
  116. var op = tensor.op;
  117. var new_op = op_cache.ContainsKey(op.Name) ? op_cache[op.Name] : null;
  118. if(new_op == null)
  119. {
  120. new_op = _safe_initial_value_from_op(op, op_cache);
  121. op_cache[op.Name] = new_op;
  122. }
  123. return new_op.outputs[tensor.value_index];
  124. }
  125. private Operation _safe_initial_value_from_op(Operation op, Dictionary<string, Operation> op_cache)
  126. {
  127. var op_type = op.node_def.Op;
  128. switch (op_type)
  129. {
  130. case "IsVariableInitialized":
  131. case "VarIsInitializedOp":
  132. case "ReadVariableOp":
  133. return op;
  134. case "Variable":
  135. case "VariableV2":
  136. case "VarHandleOp":
  137. break;
  138. }
  139. // Recursively build initializer expressions for inputs.
  140. return op;
  141. }
  142. }
  143. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此方便地在 .NET 平台下搭建深度学习训练与推理流程。