From a8a0731c4d012fa8e12c233023af5f2f0e19fb2d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 11 Jul 2020 12:43:58 -0500 Subject: [PATCH] Change RefVariable to IVariableV1. --- src/TensorFlowNET.Core/APIs/tf.compat.v1.cs | 25 +++++++++ src/TensorFlowNET.Core/APIs/tf.nn.cs | 10 ++-- src/TensorFlowNET.Core/APIs/tf.state.cs | 2 +- src/TensorFlowNET.Core/APIs/tf.train.cs | 4 +- src/TensorFlowNET.Core/APIs/tf.variable.cs | 24 -------- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Keras/Optimizers/PolynomialDecay.cs | 2 +- src/TensorFlowNET.Core/Layers/Layer.cs | 2 +- src/TensorFlowNET.Core/Operations/math_ops.cs | 6 +- .../Training/AdamOptimizer.cs | 16 +++--- .../Training/ExponentialMovingAverage.cs | 4 +- src/TensorFlowNET.Core/Training/Optimizer.cs | 22 ++++---- .../Training/SlotCreator.cs | 10 ++-- .../Training/TrainingUtil.cs | 4 +- .../Training/gen_training_ops.cs | 2 +- .../Training/moving_averages.cs | 4 +- .../Variables/BaseResourceVariable.cs | 16 ++++-- .../Variables/IVariableV1.cs | 5 +- .../Variables/RefVariable.Implicit.cs | 2 +- .../Variables/RefVariable.cs | 6 +- .../Variables/VariableScope.cs | 4 +- .../Variables/_VariableStore.cs | 55 ++++--------------- .../Variables/gen_state_ops.py.cs | 4 +- src/TensorFlowNET.Core/Variables/state_ops.cs | 6 +- .../Variables/variable_scope.py.cs | 2 +- src/TensorFlowNET.Core/ops.cs | 5 ++ 26 files changed, 111 insertions(+), 133 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs b/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs index 97e3ac6e..63833991 100644 --- a/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs +++ b/src/TensorFlowNET.Core/APIs/tf.compat.v1.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Collections.Generic; using Tensorflow.Eager; using static Tensorflow.Binding; @@ -26,5 +27,29 @@ namespace Tensorflow { tf.context.default_execution_mode = Context.GRAPH_MODE; } + + public IVariableV1 get_variable(string name, + TensorShape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + object initializer = null, // IInitializer or Tensor + bool? trainable = null, + List collections = null, + bool? use_resource = null, + bool validate_shape = true, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None) + { + var scope = Tensorflow.variable_scope.get_variable_scope(); + var store = Tensorflow.variable_scope._get_default_variable_store(); + return scope.get_variable(store, + name, + shape: shape, + dtype: dtype, + use_resource: use_resource, + validate_shape: validate_shape, + initializer: initializer, + trainable: trainable, + collections: collections); + } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 3f756502..3d470ea3 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -27,13 +27,13 @@ namespace Tensorflow public class nn_internal { - public Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, + public Tensor conv2d(Tensor input, IVariableV1 filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, string data_format= "NHWC", int[] dilations= null, string name = null) { var parameters = new Conv2dParams { Input = input, - Filter = filter, + Filter = filter.AsTensor(), Strides = strides, Padding = padding, UseCudnnOnGpu = use_cudnn_on_gpu, @@ -98,7 +98,7 @@ namespace Tensorflow name: name, keep_dims: keep_dims); - public Tensor embedding_lookup(RefVariable @params, + public Tensor embedding_lookup(IVariableV1 @params, Tensor ids, string partition_strategy = "mod", string name = null) => embedding_ops._embedding_lookup_and_transform(@params, @@ -150,12 +150,12 @@ namespace Tensorflow public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); - public Tensor bias_add(Tensor value, RefVariable bias, string data_format = null, string name = null) + public Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null, string name = null) { return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => { name = scope; - return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name); + return gen_nn_ops.bias_add(value, bias.AsTensor(), data_format: data_format, name: name); }); } diff --git a/src/TensorFlowNET.Core/APIs/tf.state.cs b/src/TensorFlowNET.Core/APIs/tf.state.cs index cb5cb4f5..d86f88b1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.state.cs +++ b/src/TensorFlowNET.Core/APIs/tf.state.cs @@ -18,7 +18,7 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor assign_add(IVariableV1 @ref, T value, + public ITensorOrOperation assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) => state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); } diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index b3819b7b..df10e79e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -26,10 +26,10 @@ namespace Tensorflow public class train_internal { - public RefVariable create_global_step(Graph graph) + public IVariableV1 create_global_step(Graph graph) => TrainingUtil.create_global_step(graph); - public RefVariable get_global_step(Graph graph) + public IVariableV1 get_global_step(Graph graph) => TrainingUtil.get_global_step(graph); public Optimizer GradientDescentOptimizer(float learning_rate) diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index 5ebc305b..c730e805 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -50,30 +50,6 @@ namespace Tensorflow public IVariableV1[] trainable_variables(string scope = null) => (variables.trainable_variables() as List).ToArray(); - public RefVariable get_variable(string name, - TensorShape shape = null, - TF_DataType dtype = TF_DataType.DtInvalid, - object initializer = null, // IInitializer or Tensor - bool? trainable = null, - List collections = null, - bool? use_resource = null, - bool validate_shape = true, - VariableSynchronization synchronization = VariableSynchronization.Auto, - VariableAggregation aggregation = VariableAggregation.None) - { - var scope = Tensorflow.variable_scope.get_variable_scope(); - var store = Tensorflow.variable_scope._get_default_variable_store(); - return scope.get_variable(store, - name, - shape: shape, - dtype: dtype, - use_resource: use_resource, - validate_shape: validate_shape, - initializer: initializer, - trainable: trainable, - collections: collections); - } - public VariableScope get_variable_scope() => Tensorflow.variable_scope.get_variable_scope(); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 4de12257..12dc66fb 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -535,7 +535,7 @@ namespace Tensorflow string debugString = string.Empty; public override string ToString() { - return $"{graph_key}, ({_handle})"; + return $"{graph_key}, 0x{_handle.ToString("x16")}"; /*if (string.IsNullOrEmpty(debugString)) { int len = 0; diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs b/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs index fe1604cf..0d9c8306 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Optimizers this.name = name; } - public Tensor __call__(RefVariable step) + public Tensor __call__(IVariableV1 step) { return tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope => { diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 83dc8c99..0d6d2f69 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -161,7 +161,7 @@ namespace Tensorflow.Layers initializer: initializer, trainable: trainable, getter: (name1, shape1, dtype1, initializer1, trainable1) => - tf.get_variable(name1, + tf.compat.v1.get_variable(name1, shape: new TensorShape(shape1), dtype: dtype1, initializer: initializer1, diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a9b597be..2c5a557a 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -68,11 +68,11 @@ namespace Tensorflow return gen_math_ops.add_n(inputs, name: name); } - public static Tensor cast(RefVariable x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + public static Tensor cast(IVariableV1 x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) { var base_type = dtype.as_base_dtype(); if (base_type == x.dtype) - return x; + return x.AsTensor(); return tf_with(ops.name_scope(name, "Cast", new { x }), scope => { @@ -81,7 +81,7 @@ namespace Tensorflow if (t_x.dtype.as_base_dtype() != base_type) t_x = gen_math_ops.cast(t_x, base_type, name: name); - return x; + return x.AsTensor(); }); } diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs index 23cc951f..5c3672ff 100644 --- a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs @@ -79,7 +79,7 @@ namespace Tensorflow.Train use_locking: _use_locking).op; } - private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func scatter_add) + private Operation _apply_sparse_shared(Tensor grad, IVariableV1 var, Tensor indices, Func scatter_add) { var (beta1_power_v, beta2_power_v) = _get_beta_accumulators(); Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype()); @@ -91,7 +91,7 @@ namespace Tensorflow.Train var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); var m = get_slot(var, "m"); var m_scaled_g_values = grad * (1 - beta1_t); - var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking); + var m_t = state_ops.assign(m.AsTensor(), m.AsTensor() * beta1_t, use_locking: _use_locking); tf_with(ops.control_dependencies(new[] { m_t }), delegate { m_t = scatter_add(m, indices, m_scaled_g_values); @@ -99,7 +99,7 @@ namespace Tensorflow.Train var v = get_slot(var, "v"); var v_scaled_g_values = (grad * grad) * (1 - beta2_t); - var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking); + var v_t = state_ops.assign(v.AsTensor(), v.AsTensor() * beta2_t, use_locking: _use_locking); tf_with(ops.control_dependencies(new[] { v_t }), delegate { v_t = scatter_add(v, indices, v_scaled_g_values); @@ -132,8 +132,8 @@ namespace Tensorflow.Train { var (beta1_power, beta2_power) = _get_beta_accumulators(); ops.colocate_with(beta1_power); - var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking); - var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking); + var update_beta1 = beta1_power.assign(beta1_power.AsTensor() * _beta1_t, use_locking: _use_locking); + var update_beta2 = beta2_power.assign(beta2_power.AsTensor() * _beta2_t, use_locking: _use_locking); operations.Add(update_beta1); operations.Add(update_beta2); @@ -142,12 +142,12 @@ namespace Tensorflow.Train return control_flow_ops.group(operations.ToArray(), name: name_scope); } - private (RefVariable, RefVariable) _get_beta_accumulators() + private (IVariableV1, IVariableV1) _get_beta_accumulators() { ops.init_scope(); var graph = ops.get_default_graph(); - return (_get_non_slot_variable("beta1_power", graph: graph) as RefVariable, - _get_non_slot_variable("beta2_power", graph: graph) as RefVariable); + return (_get_non_slot_variable("beta1_power", graph: graph), + _get_non_slot_variable("beta2_power", graph: graph)); } public override void _prepare() diff --git a/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs index cc3527c2..bdcc69cc 100644 --- a/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs +++ b/src/TensorFlowNET.Core/Training/ExponentialMovingAverage.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Train bool _zero_debias; string _name; public string name => _name; - Dictionary _averages; + Dictionary _averages; public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, string name = "ExponentialMovingAverage") @@ -22,7 +22,7 @@ namespace Tensorflow.Train _num_updates = num_updates; _zero_debias = zero_debias; _name = name; - _averages = new Dictionary(); + _averages = new Dictionary(); } /// diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index ebe9690d..e9beba08 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -43,7 +43,7 @@ namespace Tensorflow protected Tensor _lr_t; public Tensor LearningRateTensor => _lr_t; public bool _use_locking; - public Dictionary> _slots; + public Dictionary> _slots; public Dictionary _non_slot_dict; public Dictionary _deferred_slot_restorations; SlotCreator slot_creator = new SlotCreator(); @@ -57,7 +57,7 @@ namespace Tensorflow _use_locking = use_locking; _lr = learning_rate; // Dictionary of slots. - _slots = new Dictionary>(); + _slots = new Dictionary>(); _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -71,7 +71,7 @@ namespace Tensorflow _use_locking = use_locking; _lr_t = learning_rate; // Dictionary of slots. - _slots = new Dictionary>(); + _slots = new Dictionary>(); _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -207,7 +207,7 @@ namespace Tensorflow { apply_updates = state_ops.assign_add(global_step, ops.convert_to_tensor(1, dtype: global_step.dtype), - name: name); + name: name) as Operation; } }); } @@ -241,7 +241,7 @@ namespace Tensorflow /// /// /// - protected IVariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) + protected IVariableV1 _create_non_slot_variable(float initial_value, string name, IVariableV1 colocate_with) { // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. var graph = colocate_with.Graph; @@ -338,7 +338,7 @@ namespace Tensorflow /// /// /// - protected RefVariable get_slot(RefVariable var, string name) + protected IVariableV1 get_slot(IVariableV1 var, string name) { var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; if (named_slots == null) @@ -347,7 +347,7 @@ namespace Tensorflow return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; } - private string _var_key(RefVariable var) + private string _var_key(IVariableV1 var) { return $"{var.Op.graph.graph_key}.{var.Op.name}"; } @@ -438,7 +438,7 @@ namespace Tensorflow /// /// /// - protected RefVariable _zeros_slot(RefVariable var, string slot_name, string op_name) + protected IVariableV1 _zeros_slot(IVariableV1 var, string slot_name, string op_name) { var named_slots = _slot_dict(slot_name); if (!named_slots.ContainsKey(_var_key(var))) @@ -453,18 +453,18 @@ namespace Tensorflow /// /// Restore a newly created slot variable's value. /// - protected void _restore_slot_variable(string slot_name, RefVariable variable, RefVariable slot_variable) + protected void _restore_slot_variable(string slot_name, IVariableV1 variable, IVariableV1 slot_variable) { var variable_key = _var_key(variable); // TODO } - protected Dictionary _slot_dict(string slot_name) + protected Dictionary _slot_dict(string slot_name) { var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null; if(named_slots == null) { - named_slots = new Dictionary(); + named_slots = new Dictionary(); _slots[slot_name] = named_slots; } diff --git a/src/TensorFlowNET.Core/Training/SlotCreator.cs b/src/TensorFlowNET.Core/Training/SlotCreator.cs index 3a27158d..408f639c 100644 --- a/src/TensorFlowNET.Core/Training/SlotCreator.cs +++ b/src/TensorFlowNET.Core/Training/SlotCreator.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Train /// /// /// - public RefVariable create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) + public IVariableV1 create_slot(RefVariable primary, Tensor val, string name, bool colocate_with_primary = true) { var validate_shape = val.TensorShape.is_fully_defined(); var prefix = primary.Op.name; @@ -48,7 +48,7 @@ namespace Tensorflow.Train /// /// /// - public RefVariable create_zeros_slot(RefVariable primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true) + public IVariableV1 create_zeros_slot(IVariableV1 primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true) { if (dtype == TF_DataType.DtInvalid) dtype = primary.dtype; @@ -70,7 +70,7 @@ namespace Tensorflow.Train /// Creates a slot initialized using an `Initializer`. /// /// - public RefVariable create_slot_with_initializer(RefVariable primary, IInitializer initializer, TensorShape shape, + public IVariableV1 create_slot_with_initializer(IVariableV1 primary, IInitializer initializer, TensorShape shape, TF_DataType dtype, string name, bool colocate_with_primary = true) { var validate_shape = shape.is_fully_defined(); @@ -91,14 +91,14 @@ namespace Tensorflow.Train /// /// /// - private RefVariable _create_slot_var(IVariableV1 primary, object val, string scope, bool validate_shape, + private IVariableV1 _create_slot_var(IVariableV1 primary, object val, string scope, bool validate_shape, TensorShape shape, TF_DataType dtype) { bool use_resource = primary is ResourceVariable; if (resource_variable_ops.is_resource_variable(primary)) use_resource = true; - var slot = tf.get_variable( + var slot = tf.compat.v1.get_variable( scope, initializer: val, trainable: false, diff --git a/src/TensorFlowNET.Core/Training/TrainingUtil.cs b/src/TensorFlowNET.Core/Training/TrainingUtil.cs index 79a1de4b..dbfe916b 100644 --- a/src/TensorFlowNET.Core/Training/TrainingUtil.cs +++ b/src/TensorFlowNET.Core/Training/TrainingUtil.cs @@ -7,7 +7,7 @@ namespace Tensorflow.Train { public class TrainingUtil { - public static RefVariable create_global_step(Graph graph = null) + public static IVariableV1 create_global_step(Graph graph = null) { graph = graph ?? ops.get_default_graph(); if (get_global_step(graph) != null) @@ -16,7 +16,7 @@ namespace Tensorflow.Train // Create in proper graph and base name_scope. var g = graph.as_default(); g.name_scope(null); - var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, + var v = tf.compat.v1.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, initializer: tf.zeros_initializer, trainable: false, aggregation: VariableAggregation.OnlyFirstReplica, diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs index 7de977f4..ee637b7d 100644 --- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -23,7 +23,7 @@ namespace Tensorflow { public class gen_training_ops { - public static Tensor apply_adam(RefVariable var, RefVariable m, RefVariable v, Tensor beta1_power, Tensor beta2_power, + public static Tensor apply_adam(IVariableV1 var, IVariableV1 m, IVariableV1 v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool use_locking = false, bool use_nesterov = false, string name = null) { diff --git a/src/TensorFlowNET.Core/Training/moving_averages.cs b/src/TensorFlowNET.Core/Training/moving_averages.cs index de4e7f2e..be91a4b7 100644 --- a/src/TensorFlowNET.Core/Training/moving_averages.cs +++ b/src/TensorFlowNET.Core/Training/moving_averages.cs @@ -16,7 +16,7 @@ namespace Tensorflow.Train /// /// /// - public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, + public static Tensor assign_moving_average(IVariableV1 variable, IVariableV1 value, Tensor decay, bool zero_debias = true, string name = null) { return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope => @@ -25,7 +25,7 @@ namespace Tensorflow.Train if (decay.dtype != variable.dtype.as_base_dtype()) decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); - return state_ops.assign_sub(variable, (variable - value) * decay, name: scope); + return state_ops.assign_sub(variable, (variable.AsTensor() - value.AsTensor()) * decay, name: scope); }); } } diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 822d1f66..9adc5e4f 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -70,14 +70,15 @@ namespace Tensorflow // handle_deleter } - public BaseResourceVariable assign(object value, bool use_locking = false, string name = null, bool read_value = true) + public ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true) { var value_tensor = ops.convert_to_tensor(value, dtype: dtype); var assign_op = gen_resource_variable_ops.assign_variable_op( handle, value_tensor, name: name); if (read_value) - return _lazy_read(assign_op, value_tensor); - return null; + return gen_resource_variable_ops.read_variable_op(handle, dtype); + // return _lazy_read(assign_op, value_tensor); + return assign_op; } public Tensor value() => _read_variable_op(); @@ -122,13 +123,14 @@ namespace Tensorflow return array_ops.identity(value); }); - public Operation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) + public ITensorOrOperation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) { var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle, ops.convert_to_tensor(delta, dtype: dtype), name: name); - /*if (read_value) - return _lazy_read(assign_add_op);*/ + if (read_value) + return gen_resource_variable_ops.read_variable_op(handle, dtype); + // return _lazy_read(assign_add_op); return assign_add_op; } @@ -145,5 +147,7 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { } + + public Tensor AsTensor() => _graph_element; } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index e7389522..68a1b78a 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -38,6 +38,9 @@ namespace Tensorflow public Tensor GraphElement { get; } public Graph Graph { get; } public TF_DataType dtype { get; } - public Operation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); + public TensorShape shape { get; } + ITensorOrOperation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); + ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true); + Tensor AsTensor(); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs index 864dc8c4..6bc90ae9 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -14,7 +14,7 @@ public static implicit operator Tensor(RefVariable var) { - return var._AsTensor(); + return var.AsTensor(); } public static implicit operator RefVariable(Tensor var) diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 38d124ac..7ed50bcc 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -220,7 +220,7 @@ namespace Tensorflow public Tensor value() => _snapshot; - public Tensor _AsTensor() => _snapshot; + public Tensor AsTensor() => _snapshot; public Tensor _as_graph_element() => _variable; @@ -333,7 +333,7 @@ namespace Tensorflow /// A `Tensor` that will hold the new value of this variable after /// the assignment has completed. /// - public ITensorOrOperation assign(object value, bool use_locking = false, string name = null, bool read_value = true) + public ITensorOrOperation assign(T value, bool use_locking = false, string name = null, bool read_value = true) { var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); if (read_value) @@ -416,7 +416,7 @@ namespace Tensorflow // name: A name for the operation(optional). // Returns: // A mutable `Tensor`. Has the same type as `ref`. - public Operation assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) + public ITensorOrOperation assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) { var variable = this; var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index a95a101c..acfaac95 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -47,7 +47,7 @@ namespace Tensorflow _dtype = dtype; } - public RefVariable get_variable(_VariableStore var_store, + public IVariableV1 get_variable(_VariableStore var_store, string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -73,7 +73,7 @@ namespace Tensorflow trainable: trainable, collections: collections, synchronization: synchronization, - aggregation: aggregation) as RefVariable; + aggregation: aggregation); }); } diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 5b6b6b54..291ad99b 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -92,7 +92,7 @@ namespace Tensorflow return _get_single_variable(name: name, shape: shape, dtype: dtype, - initializer: tensor, + init_value: tensor, trainable: trainable, validate_shape: validate_shape, synchronization: synchronization, @@ -116,6 +116,7 @@ namespace Tensorflow TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, + Tensor init_value = null, bool reuse = false, bool? trainable = null, List collections = null, @@ -124,9 +125,9 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) { - bool initializing_from_value = false; + bool initializing_from_value = init_value != null; if (use_resource == null) - use_resource = false; + use_resource = variable_scope._DEFAULT_USE_RESOURCE; if (_vars.ContainsKey(name)) { @@ -140,7 +141,7 @@ namespace Tensorflow IVariableV1 v = null; // Create the tensor to initialize the variable with default value. - if (initializer == null) + if (initializer == null && init_value == null) { if (dtype.is_floating()) { @@ -154,7 +155,10 @@ namespace Tensorflow { if (initializing_from_value) { - + v = new ResourceVariable(init_value, + name: name, + validate_shape: validate_shape, + trainable: trainable.Value); } else { @@ -166,6 +170,7 @@ namespace Tensorflow trainable: trainable, collections: collections, dtype: variable_dtype, + use_resource: use_resource, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); @@ -176,45 +181,5 @@ namespace Tensorflow return v; } - - private RefVariable _get_single_variable(string name, - TensorShape shape = null, - TF_DataType dtype = TF_DataType.DtInvalid, - Tensor initializer = null, - bool reuse = false, - bool? trainable = null, - bool validate_shape = false, - bool? use_resource = null, - VariableSynchronization synchronization = VariableSynchronization.Auto, - VariableAggregation aggregation = VariableAggregation.None) - { - if (use_resource == null) - use_resource = false; - - if (_vars.ContainsKey(name)) - { - if (!reuse) - { - var var = _vars[name]; - - } - throw new NotImplementedException("_get_single_variable"); - } - - RefVariable v = null; - // Create the variable. - ops.init_scope(); - { - var init_val = initializer; - v = new RefVariable(init_val, - name: name, - validate_shape: validate_shape, - trainable: trainable.Value); - } - - _vars[name] = v; - - return v; - } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 9a566d70..3440eb76 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -110,7 +110,7 @@ namespace Tensorflow return _result[0]; } - public static Tensor assign_sub(RefVariable @ref, + public static Tensor assign_sub(IVariableV1 @ref, Tensor value, bool use_locking = false, string name = null) @@ -129,7 +129,7 @@ namespace Tensorflow /// /// /// - public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) { var _op = tf._op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 4ad626ef..ad621915 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -78,7 +78,7 @@ namespace Tensorflow name: name); } - public static Tensor assign_sub(RefVariable @ref, + public static Tensor assign_sub(IVariableV1 @ref, Tensor value, bool use_locking = false, string name = null) => gen_state_ops.assign_sub(@ref, @@ -106,13 +106,13 @@ namespace Tensorflow // Returns: // Same as "ref". Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static Operation assign_add(IVariableV1 @ref, + public static ITensorOrOperation assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) => @ref.assign_add(value, use_locking: use_locking, name: name); - public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) + public static Tensor scatter_add(IVariableV1 @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) { if (@ref.dtype.is_ref_dtype()) return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name); diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index f538dd02..41f1132d 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -27,7 +27,7 @@ namespace Tensorflow { public static string _VARSTORE_KEY = "__variable_store"; public static string _VARSCOPESTORE_KEY = "__varscope"; - public static bool _DEFAULT_USE_RESOURCE = false; + public static bool _DEFAULT_USE_RESOURCE = true; private bool _use_resource; public bool UseResource => _use_resource; diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 93e2ba7a..d23d49bb 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -304,6 +304,11 @@ namespace Tensorflow _colocate_with_for_gradient(tensor.op, null, ignore_existing); } + public static void colocate_with(IVariableV1 variable, bool ignore_existing = false) + { + _colocate_with_for_gradient(variable.AsTensor(), null, ignore_existing); + } + public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false) { var default_graph = get_default_graph();