diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index c850abc2..28115584 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -268,6 +268,16 @@ namespace Tensorflow

         public static Tensor rank(Tensor input, string name = null)
         {
+            if (tf.context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
+                    "Rank", name,
+                    null,
+                    input);
+
+                return results[0];
+            }
+
             var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input });

             return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 0c167691..a9b597be 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -567,7 +567,7 @@ namespace Tensorflow
             }
             else
             {
-                if(x is Tensor)
+                if(x.rank > -1)
                     return constant_op.constant(np.arange(x.rank));

                 var rank = array_ops.rank(x);
diff --git a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
index 1210af3b..23cc951f 100644
--- a/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Training/AdamOptimizer.cs
@@ -109,7 +109,7 @@ namespace Tensorflow.Train
             return control_flow_ops.group(new[] { var_update, m_t, v_t });
         }

-        protected override void _create_slots(RefVariable[] var_list)
+        protected override void _create_slots(ResourceVariable[] var_list)
         {
             var first_var = var_list.OrderBy(x => x.Name).First();
             _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);
diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs
index 1a77858e..e3c4ca03 100644
--- a/src/TensorFlowNET.Core/Training/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Training/Optimizer.cs
@@ -107,7 +107,7 @@ namespace Tensorflow
         ///
         public Operation minimize(Tensor loss,
             RefVariable global_step = null,
-            List<RefVariable> var_list=null,
+            List<ResourceVariable> var_list=null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             int? aggregation_method=null,
             bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null)
@@ -142,17 +142,17 @@ namespace Tensorflow
         ///
         /// An `Operation` that applies the specified gradients. If `global_step`
         /// was not None, that operation also increments `global_step`.
-        public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
+        public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
         {
             // No DistributionStrategy case.
-            var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>();
+            var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
             foreach (var (g, v) in grads_and_vars)
             {
                 if(g != null)
                 {
                     // Convert the grad to Tensor or IndexedSlices if necessary.
                     var gR = ops.convert_to_tensor_or_indexed_slices(g);
-                    var p = _get_processor(v);
+                    var p = optimizer._get_processor(v);
                     converted_grads_and_vars.Add((gR, v, p));
                 }
             }
@@ -230,7 +230,7 @@ namespace Tensorflow
         /// silently ignored).
         ///
         ///
-        protected virtual void _create_slots(RefVariable[] var_list)
+        protected virtual void _create_slots(ResourceVariable[] var_list)
         {

         }
@@ -276,6 +276,12 @@ namespace Tensorflow
             return control_flow_ops.group(update_ops, name_scope);
         }

+        public virtual Operation _apply_dense(Tensor grad, ResourceVariable var)
+        {
+            var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
+            return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op;
+        }
+
         public virtual Operation _apply_dense(Tensor grad, RefVariable var)
         {
             var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
@@ -298,6 +304,16 @@ namespace Tensorflow
             return _apply_sparse(gradient_no_duplicate_indices, var);
         }

+        public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var)
+        {
+            var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices);
+            var gradient_no_duplicate_indices = new IndexedSlices(
+                indices: unique_indices,
+                values: summed_values,
+                dense_shape: grad.dense_shape);
+            return _apply_sparse(gradient_no_duplicate_indices, var);
+        }
+
         public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
         {
             throw new NotImplementedException("_apply_sparse");
@@ -344,18 +360,6 @@ namespace Tensorflow
             return non_slot;
         }

-        private _OptimizableVariable _get_processor(RefVariable v)
-        {
-            if(v is RefVariable)
-            {
-                return new _RefVariableProcessor(v);
-            }
-            else
-            {
-                throw new NotImplementedException("_get_processor");
-            }
-        }
-
         ///
         /// Compute gradients of `loss` for the variables in `var_list`.
         ///
@@ -365,8 +369,8 @@ namespace Tensorflow
         /// A list of (gradient, variable) pairs. Variable is always present, but
         /// gradient can be `None`.
         ///
-        public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss,
-            List<RefVariable> var_list = null,
+        public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
+            List<ResourceVariable> var_list = null,
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
@@ -374,26 +378,28 @@ namespace Tensorflow
         {
             // Scale loss if using a "mean" loss reduction and multiple replicas.
             loss = _scale_loss(loss);
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            int num_towers = 1;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used

             if(var_list == null)
             {
-                var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
                 var tmp = variables.trainable_variables();
                 switch (tmp)
                 {
-                    case List<RefVariable> values:
+                    case List<ResourceVariable> values:
+                        var_list = values.Concat(vars).ToList();
+                        break;
+                    /*case List<RefVariable> values:
                         var_list = values.Concat(vars).ToList();
                         break;
                     case List<VariableV1> values:
                         var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                        break;
+                        break;*/
+                    default:
+                        throw new NotImplementedException("");
                 }
             }

-            var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
+            var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();

             var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();
@@ -406,7 +412,7 @@ namespace Tensorflow
             grads = control_flow_ops.tuple(grads);

             var grads_and_vars = zip(grads, var_list)
-                .Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2))
+                .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2))
                 .ToArray();

             return grads_and_vars;
diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs
index c744733f..7de977f4 100644
--- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs
+++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs
@@ -59,7 +59,7 @@ namespace Tensorflow
             return _op.outputs[0];
         }

-        public static Operation resource_apply_gradient_descent(EagerTensor var, EagerTensor alpha, EagerTensor delta, bool use_locking = false, string name = null)
+        public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
        {
             if (tf.context.executing_eagerly())
             {
@@ -79,7 +79,7 @@ namespace Tensorflow
                     use_locking
                 });

-            return _op.outputs[0];
+            return _op;
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Training/optimizer.py.cs b/src/TensorFlowNET.Core/Training/optimizer.py.cs
index 9f48d161..115af574 100644
--- a/src/TensorFlowNET.Core/Training/optimizer.py.cs
+++ b/src/TensorFlowNET.Core/Training/optimizer.py.cs
@@ -24,6 +24,11 @@ namespace Tensorflow
         {
             return new _RefVariableProcessor(v);
         }
+
+        public static _OptimizableVariable _get_processor(ResourceVariable v)
+        {
+            return new _DenseResourceVariableProcessor(v);
+        }
     }

     public class _RefVariableProcessor : _OptimizableVariable
@@ -56,4 +61,35 @@ namespace Tensorflow
             return update_op;
         }
     }
+
+    public class _DenseResourceVariableProcessor : _OptimizableVariable
+    {
+        private ResourceVariable _v;
+
+        public _DenseResourceVariableProcessor(ResourceVariable v)
+        {
+            _v = v;
+        }
+
+        public Tensor target()
+        {
+            return _v.Handle;
+        }
+
+        public Operation update_op(Optimizer optimizer, Tensor g)
+        {
+            Operation update_op = null;
+
+            if (g.Tag == null)
+            {
+                update_op = optimizer._apply_dense(g, _v);
+            }
+            else if (g.Tag is IndexedSlices)
+            {
+                return optimizer._apply_sparse_duplicate_indices(g, _v);
+            }
+
+            return update_op;
+        }
+    }
 }
diff --git a/src/TensorFlowNET.Core/WeakKeyDicionary.cs b/src/TensorFlowNET.Core/WeakKeyDicionary.cs
deleted file mode 100644
index c6504282..00000000
--- a/src/TensorFlowNET.Core/WeakKeyDicionary.cs
+++ /dev/null
@@ -1,438 +0,0 @@
-/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; - -namespace Tensorflow -{ - public class WeakKeyDictionary : IDictionary - { - - private Dictionary _internalDictionary; - private object _internalObject = new object(); - private bool _finalized; - - public WeakKeyDictionary() - { - _internalDictionary = new Dictionary(new WeakComparer()); - } - - public WeakKeyDictionary(int capacity) - { - _internalDictionary = new Dictionary(capacity, new WeakComparer()); - } - - public WeakKeyDictionary(IEqualityComparer comparer) - { - _internalDictionary = new Dictionary(new WeakComparer(comparer)); - } - - public WeakKeyDictionary(int capacity, IEqualityComparer comparer) - { - _internalDictionary = new Dictionary(capacity, new WeakComparer(comparer)); - } - - // FXCop: this is not empty; we need to mark this so we know if a key - // still has an active dictionary at its finalization. - [SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] - ~WeakKeyDictionary() - { - _finalized = true; - } - - public ICollection Keys - { - get - { - List list = new List(); - lock (_internalObject) - { - foreach (WeakKey key in _internalDictionary.Keys) - { - object TKey = key.Target; - if (TKey != null) - { - list.Add((TKey)TKey); - } - } - } - return list; - } - } - - public ICollection Values - { - get { - lock (_internalObject) { - return _internalDictionary.Values; - } - } - } - - public int Count - { - get - { - // Ensure a fairly accurate count. - ScavangeLostKeys(); - lock (_internalObject) - { - return _internalDictionary.Count; - } - } - } - - public bool IsReadOnly - { - get { - return false; - } - } - - [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] - public TValue this[TKey key] - { - get { - lock (_internalObject) { - return _internalDictionary[new WeakKey(key)]; - } - } - set - { - WeakKey Tkey = new WeakKey(key); - lock (_internalObject) - { - //_internalDictionary[Tkey] = value; - _internalDictionary.Add(Tkey, value); - } - // This looks a bit weird but the purpose of the lost key finder is to execute - // code in some future garbage collection phase so we immediately create some garbage. 
- new LostKeyFinder(this, Tkey); - } - } - - - - - - public bool TryGetValue(TKey key, out TValue value) - { - WeakKey tkey = new WeakKey(key); - lock (_internalObject) - { - return _internalDictionary.TryGetValue(tkey, out value); - } - } - - - [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] - public void Add(TKey key, TValue value) - { - WeakKey tkey = new WeakKey(key); - lock (_internalObject) - { - _internalDictionary.Add(tkey, value); - } - // This looks a bit weird but the purpose of the lost key finder is to execute - // code in some future garbage collection phase so we immediately create some garbage. - new LostKeyFinder(this, tkey); - - } - - public bool ContainsKey(TKey key) - { - return _internalDictionary.ContainsKey(new WeakKey(key)); - } - - public bool Remove(TKey key) - { - lock (_internalObject) - { - return _internalDictionary.Remove(new WeakKey(key)); - } - } - - public void Add(KeyValuePair item) - { - Add(item.Key, item.Value); - } - - public void Clear() - { - lock (_internalObject) - { - _internalDictionary.Clear(); - } - } - - public bool Contains(KeyValuePair item) - { - TValue value; - bool result; - lock (_internalObject) - { - result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); - } - if (result) - { - return value.Equals(item.Value); - } - else - { - return false; - } - } - - public void CopyTo(KeyValuePair[] array, int arrayIndex) - { - lock (_internalObject) - { - foreach (KeyValuePair item in _internalDictionary) - { - KeyValuePair kv = new KeyValuePair((TKey)item.Key.Target, item.Value); - array[arrayIndex] = kv; - arrayIndex++; - } - } - } - - public bool Remove(KeyValuePair item) - { - WeakKey key = new WeakKey(item.Key); - lock (_internalObject) - { - return _internalDictionary.Remove(key); - } - } - - - - - - public IEnumerator> GetEnumerator() - { - List lostKeys = null; - lock (_internalObject) - { - foreach (KeyValuePair item in _internalDictionary) - { - object TKey = item.Key.Target; - if (TKey != null) - { - yield return new KeyValuePair((TKey)TKey, item.Value); - } - else - { - if (lostKeys == null) - { - lostKeys = new List(); - } - lostKeys.Add(item.Key); - } - } - } - // Recover any lost keys. 
- if (lostKeys != null) - { - lock (_internalObject) - { - foreach (WeakKey key in lostKeys) - { - _internalDictionary.Remove(key); - } - } - } - } - - - - - IEnumerator IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - - - private void ScavangeLostKeys() - { - List lostKeys = null; - lock (_internalObject) - { - foreach (WeakKey key in _internalDictionary.Keys) - { - if (!key.IsAlive) - { - if (lostKeys == null) - { - lostKeys = new List(); - } - lostKeys.Add(key); - } - } - } - if (lostKeys != null) - { - lock (_internalObject) - { - foreach (WeakKey key in lostKeys) - { - _internalDictionary.Remove(key); - } - } - } - } - - IEnumerator> IEnumerable>.GetEnumerator() - { - return this.GetEnumerator(); - } - - private class WeakKey : WeakReference - { - private int _hashCode; - // private GCHandle _gcHandle; - - public WeakKey(TKey key) - : base(key, true) - { - _hashCode = key.GetHashCode(); - // Keep the key alive until it is explicitly collected - // _gcHandle = GCHandle.Alloc(this); - } - - internal void Release() - { - // _gcHandle.Free(); - } - - public override int GetHashCode() - { - return _hashCode; - } - - public override bool Equals(object obj) - { - if (obj == null) - { - return false; - } - if (obj.GetHashCode() != _hashCode) - { - return false; - } - if (obj != this && (!IsAlive || !obj.Equals(Target))) - { - return false; - } - return true; - } - } - - private class WeakComparer : IEqualityComparer - { - - private IEqualityComparer _comparer; - public WeakComparer() - { - } - - public WeakComparer(IEqualityComparer comparer) - { - _comparer = comparer; - } - - public bool Equals(WeakKey x, WeakKey y) - { - if (x.GetHashCode() != y.GetHashCode()) - { - return false; - } - if (object.ReferenceEquals(x, y)) - { - return true; - } - object ref1 = x.Target; - if (ref1 == null) - { - return false; - } - object ref2 = y.Target; - if (ref2 == null) - { - return false; - } - - if (_comparer != null) - { - return _comparer.Equals((TKey)ref1, (TKey)ref2); - } - else - { - return ref1.Equals(ref2); - } - } - - public int GetHashCode(WeakKey obj) - { - return obj.GetHashCode(); - } - } - - private class LostKeyFinder - { - WeakKeyDictionary _dictionary; - WeakKey _key; - - public LostKeyFinder(WeakKeyDictionary dictionary, WeakKey key) - { - _dictionary = dictionary; - _key = key; - } - - ~LostKeyFinder() - { - if (_dictionary._finalized || _key == null) - { - if (_key != null) - { - _key.Release(); - _key = null; - } - return; - } - // if (!_key.IsAlive) { - if (_key.Target == null) - { - lock (_dictionary._internalObject) - { - _dictionary._internalDictionary.Remove(_key); - } - _key.Release(); - _key = null; - } - else if (_dictionary._internalDictionary.ContainsKey(_key)) - { - GC.ReRegisterForFinalize(this); - } - } - } - } -} - \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs index 41d4e6f1..a2e7d995 100644 --- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -2,13 +2,13 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using System.Linq; -using Tensorflow; +using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics { [TestClass] - public class VariableTest + public class VariableTest : EagerModeTestBase { [TestMethod] public void NewVariable() diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs 
b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
new file mode 100644
index 00000000..752e093a
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -0,0 +1,23 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using TensorFlowNET.UnitTest;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.UnitTest
+{
+    public class EagerModeTestBase : PythonTest
+    {
+        [TestInitialize]
+        public void TestInit()
+        {
+            tf.enable_eager_execution();
+        }
+
+        [TestCleanup]
+        public void TestClean()
+        {
+        }
+    }
+}
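
For reference, a minimal test built on the new EagerModeTestBase could exercise the eager fast path added to gen_array_ops.rank above. This is only a sketch and not part of the change: the RankEagerTest class is hypothetical, and the tf.constant binding plus the scalar Tensor-to-int conversion on the result are assumed to behave as in the project's other tests.

using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
    [TestClass]
    public class RankEagerTest : EagerModeTestBase   // eager execution is switched on in TestInit
    {
        [TestMethod]
        public void RankTakesEagerFastPath()
        {
            // With eager execution enabled, gen_array_ops.rank now dispatches through TFE_FastPathExecute.
            var t = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
            var r = gen_array_ops.rank(t);       // the wrapper changed in this diff
            Assert.AreEqual(2, (int)r);          // assumed scalar Tensor-to-int conversion
        }
    }
}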
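
The ResourceVariable path through apply_gradients can be pictured with a similar sketch: a dense Tensor gradient (g.Tag == null) is handed to _DenseResourceVariableProcessor, which calls _apply_dense and lands in resource_apply_gradient_descent on the variable's handle, i.e. var <- var - alpha * delta. Everything below except the new apply_gradients signature is an assumption for illustration; in particular, tf.Variable yielding a ResourceVariable under eager execution is not shown in this diff.

using System;
using Tensorflow;
using static Tensorflow.Binding;

public class ResourceVariableSgdSketch
{
    public void Run()
    {
        tf.enable_eager_execution();

        var w = tf.Variable(3.0f, name: "w");              // assumed to be ResourceVariable-backed here
        var grad = tf.constant(0.5f);                      // stand-in for dL/dw
        var opt = tf.train.GradientDescentOptimizer(0.1f);

        // grads_and_vars is now Tuple<Tensor, ResourceVariable>[]; the dense gradient is routed via
        // _DenseResourceVariableProcessor -> _apply_dense -> resource_apply_gradient_descent,
        // performing w <- w - 0.1f * grad on the variable's handle.
        opt.apply_gradients(new[] { Tuple.Create((Tensor)grad, (ResourceVariable)w) });
    }
}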