You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RefVariable.cs 3.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  namespace Tensorflow
  {
      /// <summary>
      /// A graph variable whose storage tensor is created with
      /// <c>gen_state_ops.variable_v2</c>, initialized by an assign op, and read
      /// through an identity "read" snapshot tensor.
      /// </summary>
      public class RefVariable : VariableV1
      {
          public bool _in_graph_mode = true;
          public Tensor _initial_value;
          public string _graph_key;
          public bool _trainable;
          public Tensor _variable;
          public Tensor _snapshot;

          private Operation _initializer_op;

          /// <summary>The assign op that writes the initial value into the variable.</summary>
          public Operation initializer => _initializer_op;

          /// <summary>Currently the same op as <see cref="initializer"/>.</summary>
          public Operation op => _initializer_op;

          /// <summary>
          /// Creates the variable's ops in the default graph and registers the
          /// variable in the requested graph collections.
          /// </summary>
          /// <param name="initial_value">Value converted to a tensor to initialize the variable; must not be null.</param>
          /// <param name="trainable">When true, the variable is also added to TRAINABLE_VARIABLES.</param>
          /// <param name="collections">Collection keys to register the variable under; defaults to GLOBAL_VARIABLES. The caller's list is not modified.</param>
          /// <param name="validate_shape">Forwarded to the assign op.</param>
          /// <param name="caching_device">Currently not used for device placement (TODO).</param>
          /// <param name="name">Optional name; used as the name scope for the created ops (default scope "Variable").</param>
          /// <param name="dtype">Currently ignored: the dtype is taken from the converted initial value.</param>
          /// <exception cref="ValueError">Thrown when <paramref name="initial_value"/> is null.</exception>
          public RefVariable(object initial_value,
              bool trainable = true,
              List<string> collections = null,
              bool validate_shape = true,
              string caching_device = "",
              string name = "",
              TF_DataType dtype = TF_DataType.DtInvalid) :
              base(initial_value, trainable, collections, validate_shape, caching_device, name, dtype)
          {
              _init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
          }

          /// <summary>
          /// Builds the variable, its initializer and its read snapshot, then adds
          /// the variable to the graph collections.
          /// </summary>
          private void _init_from_args(object initial_value,
              bool trainable = true,
              List<string> collections = null,
              bool validate_shape = true,
              string caching_device = "",
              string name = "",
              TF_DataType dtype = TF_DataType.DtInvalid)
          {
              if (initial_value is null)
              {
                  throw new ValueError("initial_value must be specified.");
              }

              // Work on a private copy so the caller's list is never mutated.
              collections = collections == null
                  ? new List<string> { ops.GraphKeys.GLOBAL_VARIABLES }
                  : new List<string>(collections);

              // Store the graph key so optimizers know how to only retrieve variables from
              // this graph.
              _graph_key = ops.get_default_graph()._graph_key;
              _trainable = trainable;

              // Only trainable variables belong in TRAINABLE_VARIABLES.
              if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
              {
                  collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
              }

              ops.init_scope();

              // TODO: callable initial values (init-from-function) are not supported yet;
              // the value is always converted to a tensor directly.
              var values = new List<object> { initial_value };
              using (var namescope = new ops.name_scope<object>(name, "Variable", values))
              {
                  name = namescope;

                  _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");

                  // NOTE(review): the dtype argument is overridden by the initial value's dtype.
                  var shape = _initial_value.shape;
                  dtype = _initial_value.dtype;
                  _variable = gen_state_ops.variable_v2(shape, dtype, name);

                  // TODO: when validate_shape is true, reject initial values whose shape
                  // is not fully defined.

                  // NOTE: initial values that depend on other, not-yet-initialized
                  // variables are not guarded against here.
                  _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;

                  // The read snapshot is required whether or not a caching device was given.
                  // TODO: place the snapshot on caching_device once device scopes are supported.
                  _snapshot = gen_array_ops.identity(_variable, name: "read");

                  ops.add_to_collections(collections, this);
              }
          }

          /// <summary>Returns the underlying variable tensor.</summary>
          public Tensor _ref() => _variable;

          // NOTE(review): both conversions below are stubs that always yield null.
          public static implicit operator _VariableScopeStore(RefVariable variable)
          {
              return null;
          }

          public static implicit operator RefVariable(_VariableScopeStore store)
          {
              return null;
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。