| @@ -268,6 +268,16 @@ namespace Tensorflow | |||
| public static Tensor rank(Tensor input, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Rank", name, | |||
| null, | |||
| input); | |||
| return results[0]; | |||
| } | |||
| var _op = tf._op_def_lib._apply_op_helper("Rank", name: name, args: new { input }); | |||
| return _op.outputs[0]; | |||
| @@ -567,7 +567,7 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| if(x is Tensor) | |||
| if(x.rank > -1) | |||
| return constant_op.constant(np.arange(x.rank)); | |||
| var rank = array_ops.rank(x); | |||
| @@ -109,7 +109,7 @@ namespace Tensorflow.Train | |||
| return control_flow_ops.group(new[] { var_update, m_t, v_t }); | |||
| } | |||
| protected override void _create_slots(RefVariable[] var_list) | |||
| protected override void _create_slots(ResourceVariable[] var_list) | |||
| { | |||
| var first_var = var_list.OrderBy(x => x.Name).First(); | |||
| _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); | |||
| @@ -107,7 +107,7 @@ namespace Tensorflow | |||
| /// </returns> | |||
| public Operation minimize(Tensor loss, | |||
| RefVariable global_step = null, | |||
| List<RefVariable> var_list=null, | |||
| List<ResourceVariable> var_list=null, | |||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
| int? aggregation_method=null, | |||
| bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null) | |||
| @@ -142,17 +142,17 @@ namespace Tensorflow | |||
| /// <returns> | |||
| /// An `Operation` that applies the specified gradients. If `global_step` | |||
| /// was not None, that operation also increments `global_step`.</returns> | |||
| public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||
| public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null) | |||
| { | |||
| // No DistributionStrategy case. | |||
| var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>(); | |||
| var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>(); | |||
| foreach (var (g, v) in grads_and_vars) | |||
| { | |||
| if(g != null) | |||
| { | |||
| // Convert the grad to Tensor or IndexedSlices if necessary. | |||
| var gR = ops.convert_to_tensor_or_indexed_slices(g); | |||
| var p = _get_processor(v); | |||
| var p = optimizer._get_processor(v); | |||
| converted_grads_and_vars.Add((gR, v, p)); | |||
| } | |||
| } | |||
| @@ -230,7 +230,7 @@ namespace Tensorflow | |||
| /// silently ignored). | |||
| /// </summary> | |||
| /// <param name="var_list"></param> | |||
| protected virtual void _create_slots(RefVariable[] var_list) | |||
| protected virtual void _create_slots(ResourceVariable[] var_list) | |||
| { | |||
| } | |||
| @@ -276,6 +276,12 @@ namespace Tensorflow | |||
| return control_flow_ops.group(update_ops, name_scope); | |||
| } | |||
| public virtual Operation _apply_dense(Tensor grad, ResourceVariable var) | |||
| { | |||
| var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
| return gen_training_ops.resource_apply_gradient_descent(var.Handle, alpha, grad, use_locking: _use_locking).op; | |||
| } | |||
| public virtual Operation _apply_dense(Tensor grad, RefVariable var) | |||
| { | |||
| var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype()); | |||
| @@ -298,6 +304,16 @@ namespace Tensorflow | |||
| return _apply_sparse(gradient_no_duplicate_indices, var); | |||
| } | |||
| public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, ResourceVariable var) | |||
| { | |||
| var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices); | |||
| var gradient_no_duplicate_indices = new IndexedSlices( | |||
| indices: unique_indices, | |||
| values: summed_values, | |||
| dense_shape: grad.dense_shape); | |||
| return _apply_sparse(gradient_no_duplicate_indices, var); | |||
| } | |||
| public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var) | |||
| { | |||
| throw new NotImplementedException("_apply_sparse"); | |||
| @@ -344,18 +360,6 @@ namespace Tensorflow | |||
| return non_slot; | |||
| } | |||
| private _OptimizableVariable _get_processor(RefVariable v) | |||
| { | |||
| if(v is RefVariable) | |||
| { | |||
| return new _RefVariableProcessor(v); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("_get_processor"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Compute gradients of `loss` for the variables in `var_list`. | |||
| /// </summary> | |||
| @@ -365,8 +369,8 @@ namespace Tensorflow | |||
| /// A list of (gradient, variable) pairs. Variable is always present, but | |||
| /// gradient can be `None`. | |||
| /// </returns> | |||
| public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss, | |||
| List<RefVariable> var_list = null, | |||
| public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss, | |||
| List<ResourceVariable> var_list = null, | |||
| int? aggregation_method = null, | |||
| GateGradientType gate_gradients = GateGradientType.GATE_OP, | |||
| bool colocate_gradients_with_ops = false, | |||
| @@ -374,26 +378,28 @@ namespace Tensorflow | |||
| { | |||
| // Scale loss if using a "mean" loss reduction and multiple replicas. | |||
| loss = _scale_loss(loss); | |||
| #pragma warning disable CS0219 // Variable is assigned but its value is never used | |||
| int num_towers = 1; | |||
| #pragma warning restore CS0219 // Variable is assigned but its value is never used | |||
| if(var_list == null) | |||
| { | |||
| var vars = ops.get_collection<RefVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
| var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES); | |||
| var tmp = variables.trainable_variables(); | |||
| switch (tmp) | |||
| { | |||
| case List<RefVariable> values: | |||
| case List<ResourceVariable> values: | |||
| var_list = values.Concat(vars).ToList(); | |||
| break; | |||
| /*case List<RefVariable> values: | |||
| var_list = values.Concat(vars).ToList(); | |||
| break; | |||
| case List<IVariableV1> values: | |||
| var_list = values.Select(x => x as RefVariable).Concat(vars).ToList(); | |||
| break; | |||
| break;*/ | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| var_list = var_list.Concat(ops.get_collection<RefVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
| var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList(); | |||
| var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||
| var var_refs = processors.Select(x => x.target()).ToArray(); | |||
| @@ -406,7 +412,7 @@ namespace Tensorflow | |||
| grads = control_flow_ops.tuple(grads); | |||
| var grads_and_vars = zip(grads, var_list) | |||
| .Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2)) | |||
| .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2)) | |||
| .ToArray(); | |||
| return grads_and_vars; | |||
| @@ -59,7 +59,7 @@ namespace Tensorflow | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Operation resource_apply_gradient_descent(EagerTensor var, EagerTensor alpha, EagerTensor delta, bool use_locking = false, string name = null) | |||
| public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| @@ -79,7 +79,7 @@ namespace Tensorflow | |||
| use_locking | |||
| }); | |||
| return _op.outputs[0]; | |||
| return _op; | |||
| } | |||
| } | |||
| } | |||
| @@ -24,6 +24,11 @@ namespace Tensorflow | |||
| { | |||
| return new _RefVariableProcessor(v); | |||
| } | |||
| public static _OptimizableVariable _get_processor(ResourceVariable v) | |||
| { | |||
| return new _DenseResourceVariableProcessor(v); | |||
| } | |||
| } | |||
| public class _RefVariableProcessor : _OptimizableVariable | |||
| @@ -56,4 +61,35 @@ namespace Tensorflow | |||
| return update_op; | |||
| } | |||
| } | |||
| public class _DenseResourceVariableProcessor : _OptimizableVariable | |||
| { | |||
| private ResourceVariable _v; | |||
| public _DenseResourceVariableProcessor(ResourceVariable v) | |||
| { | |||
| _v = v; | |||
| } | |||
| public Tensor target() | |||
| { | |||
| return _v.Handle; | |||
| } | |||
| public Operation update_op(Optimizer optimizer, Tensor g) | |||
| { | |||
| Operation update_op = null; | |||
| if (g.Tag == null) | |||
| { | |||
| update_op = optimizer._apply_dense(g, _v); | |||
| } | |||
| else if (g.Tag is IndexedSlices) | |||
| { | |||
| return optimizer._apply_sparse_duplicate_indices(g, _v); | |||
| } | |||
| return update_op; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,438 +0,0 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| namespace Tensorflow | |||
| { | |||
| public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue> | |||
| { | |||
| private Dictionary<WeakKey, TValue> _internalDictionary; | |||
| private object _internalObject = new object(); | |||
| private bool _finalized; | |||
| public WeakKeyDictionary() | |||
| { | |||
| _internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer()); | |||
| } | |||
| public WeakKeyDictionary(int capacity) | |||
| { | |||
| _internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer()); | |||
| } | |||
| public WeakKeyDictionary(IEqualityComparer<TKey> comparer) | |||
| { | |||
| _internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer)); | |||
| } | |||
| public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer) | |||
| { | |||
| _internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer)); | |||
| } | |||
| // FXCop: this is not empty; we need to mark this so we know if a key | |||
| // still has an active dictionary at its finalization. | |||
| [SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] | |||
| ~WeakKeyDictionary() | |||
| { | |||
| _finalized = true; | |||
| } | |||
| public ICollection<TKey> Keys | |||
| { | |||
| get | |||
| { | |||
| List<TKey> list = new List<TKey>(); | |||
| lock (_internalObject) | |||
| { | |||
| foreach (WeakKey key in _internalDictionary.Keys) | |||
| { | |||
| object TKey = key.Target; | |||
| if (TKey != null) | |||
| { | |||
| list.Add((TKey)TKey); | |||
| } | |||
| } | |||
| } | |||
| return list; | |||
| } | |||
| } | |||
| public ICollection<TValue> Values | |||
| { | |||
| get { | |||
| lock (_internalObject) { | |||
| return _internalDictionary.Values; | |||
| } | |||
| } | |||
| } | |||
| public int Count | |||
| { | |||
| get | |||
| { | |||
| // Ensure a fairly accurate count. | |||
| ScavangeLostKeys(); | |||
| lock (_internalObject) | |||
| { | |||
| return _internalDictionary.Count; | |||
| } | |||
| } | |||
| } | |||
| public bool IsReadOnly | |||
| { | |||
| get { | |||
| return false; | |||
| } | |||
| } | |||
| [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
| public TValue this[TKey key] | |||
| { | |||
| get { | |||
| lock (_internalObject) { | |||
| return _internalDictionary[new WeakKey(key)]; | |||
| } | |||
| } | |||
| set | |||
| { | |||
| WeakKey Tkey = new WeakKey(key); | |||
| lock (_internalObject) | |||
| { | |||
| //_internalDictionary[Tkey] = value; | |||
| _internalDictionary.Add(Tkey, value); | |||
| } | |||
| // This looks a bit weird but the purpose of the lost key finder is to execute | |||
| // code in some future garbage collection phase so we immediately create some garbage. | |||
| new LostKeyFinder(this, Tkey); | |||
| } | |||
| } | |||
| public bool TryGetValue(TKey key, out TValue value) | |||
| { | |||
| WeakKey tkey = new WeakKey(key); | |||
| lock (_internalObject) | |||
| { | |||
| return _internalDictionary.TryGetValue(tkey, out value); | |||
| } | |||
| } | |||
| [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
| public void Add(TKey key, TValue value) | |||
| { | |||
| WeakKey tkey = new WeakKey(key); | |||
| lock (_internalObject) | |||
| { | |||
| _internalDictionary.Add(tkey, value); | |||
| } | |||
| // This looks a bit weird but the purpose of the lost key finder is to execute | |||
| // code in some future garbage collection phase so we immediately create some garbage. | |||
| new LostKeyFinder(this, tkey); | |||
| } | |||
| public bool ContainsKey(TKey key) | |||
| { | |||
| return _internalDictionary.ContainsKey(new WeakKey(key)); | |||
| } | |||
| public bool Remove(TKey key) | |||
| { | |||
| lock (_internalObject) | |||
| { | |||
| return _internalDictionary.Remove(new WeakKey(key)); | |||
| } | |||
| } | |||
| public void Add(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| Add(item.Key, item.Value); | |||
| } | |||
| public void Clear() | |||
| { | |||
| lock (_internalObject) | |||
| { | |||
| _internalDictionary.Clear(); | |||
| } | |||
| } | |||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| TValue value; | |||
| bool result; | |||
| lock (_internalObject) | |||
| { | |||
| result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); | |||
| } | |||
| if (result) | |||
| { | |||
| return value.Equals(item.Value); | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
| { | |||
| lock (_internalObject) | |||
| { | |||
| foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
| { | |||
| KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value); | |||
| array[arrayIndex] = kv; | |||
| arrayIndex++; | |||
| } | |||
| } | |||
| } | |||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| WeakKey key = new WeakKey(item.Key); | |||
| lock (_internalObject) | |||
| { | |||
| return _internalDictionary.Remove(key); | |||
| } | |||
| } | |||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
| { | |||
| List<WeakKey> lostKeys = null; | |||
| lock (_internalObject) | |||
| { | |||
| foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
| { | |||
| object TKey = item.Key.Target; | |||
| if (TKey != null) | |||
| { | |||
| yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value); | |||
| } | |||
| else | |||
| { | |||
| if (lostKeys == null) | |||
| { | |||
| lostKeys = new List<WeakKey>(); | |||
| } | |||
| lostKeys.Add(item.Key); | |||
| } | |||
| } | |||
| } | |||
| // Recover any lost keys. | |||
| if (lostKeys != null) | |||
| { | |||
| lock (_internalObject) | |||
| { | |||
| foreach (WeakKey key in lostKeys) | |||
| { | |||
| _internalDictionary.Remove(key); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| private void ScavangeLostKeys() | |||
| { | |||
| List<WeakKey> lostKeys = null; | |||
| lock (_internalObject) | |||
| { | |||
| foreach (WeakKey key in _internalDictionary.Keys) | |||
| { | |||
| if (!key.IsAlive) | |||
| { | |||
| if (lostKeys == null) | |||
| { | |||
| lostKeys = new List<WeakKey>(); | |||
| } | |||
| lostKeys.Add(key); | |||
| } | |||
| } | |||
| } | |||
| if (lostKeys != null) | |||
| { | |||
| lock (_internalObject) | |||
| { | |||
| foreach (WeakKey key in lostKeys) | |||
| { | |||
| _internalDictionary.Remove(key); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() | |||
| { | |||
| return this.GetEnumerator(); | |||
| } | |||
| private class WeakKey : WeakReference | |||
| { | |||
| private int _hashCode; | |||
| // private GCHandle _gcHandle; | |||
| public WeakKey(TKey key) | |||
| : base(key, true) | |||
| { | |||
| _hashCode = key.GetHashCode(); | |||
| // Keep the key alive until it is explicitly collected | |||
| // _gcHandle = GCHandle.Alloc(this); | |||
| } | |||
| internal void Release() | |||
| { | |||
| // _gcHandle.Free(); | |||
| } | |||
| public override int GetHashCode() | |||
| { | |||
| return _hashCode; | |||
| } | |||
| public override bool Equals(object obj) | |||
| { | |||
| if (obj == null) | |||
| { | |||
| return false; | |||
| } | |||
| if (obj.GetHashCode() != _hashCode) | |||
| { | |||
| return false; | |||
| } | |||
| if (obj != this && (!IsAlive || !obj.Equals(Target))) | |||
| { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } | |||
| private class WeakComparer : IEqualityComparer<WeakKey> | |||
| { | |||
| private IEqualityComparer<TKey> _comparer; | |||
| public WeakComparer() | |||
| { | |||
| } | |||
| public WeakComparer(IEqualityComparer<TKey> comparer) | |||
| { | |||
| _comparer = comparer; | |||
| } | |||
| public bool Equals(WeakKey x, WeakKey y) | |||
| { | |||
| if (x.GetHashCode() != y.GetHashCode()) | |||
| { | |||
| return false; | |||
| } | |||
| if (object.ReferenceEquals(x, y)) | |||
| { | |||
| return true; | |||
| } | |||
| object ref1 = x.Target; | |||
| if (ref1 == null) | |||
| { | |||
| return false; | |||
| } | |||
| object ref2 = y.Target; | |||
| if (ref2 == null) | |||
| { | |||
| return false; | |||
| } | |||
| if (_comparer != null) | |||
| { | |||
| return _comparer.Equals((TKey)ref1, (TKey)ref2); | |||
| } | |||
| else | |||
| { | |||
| return ref1.Equals(ref2); | |||
| } | |||
| } | |||
| public int GetHashCode(WeakKey obj) | |||
| { | |||
| return obj.GetHashCode(); | |||
| } | |||
| } | |||
| private class LostKeyFinder | |||
| { | |||
| WeakKeyDictionary<TKey, TValue> _dictionary; | |||
| WeakKey _key; | |||
| public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key) | |||
| { | |||
| _dictionary = dictionary; | |||
| _key = key; | |||
| } | |||
| ~LostKeyFinder() | |||
| { | |||
| if (_dictionary._finalized || _key == null) | |||
| { | |||
| if (_key != null) | |||
| { | |||
| _key.Release(); | |||
| _key = null; | |||
| } | |||
| return; | |||
| } | |||
| // if (!_key.IsAlive) { | |||
| if (_key.Target == null) | |||
| { | |||
| lock (_dictionary._internalObject) | |||
| { | |||
| _dictionary._internalDictionary.Remove(_key); | |||
| } | |||
| _key.Release(); | |||
| _key = null; | |||
| } | |||
| else if (_dictionary._internalDictionary.ContainsKey(_key)) | |||
| { | |||
| GC.ReRegisterForFinalize(this); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -2,13 +2,13 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.Basics | |||
| { | |||
| [TestClass] | |||
| public class VariableTest | |||
| public class VariableTest : EagerModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void NewVariable() | |||
| @@ -0,0 +1,23 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using TensorFlowNET.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.UnitTest | |||
| { | |||
| public class EagerModeTestBase : PythonTest | |||
| { | |||
| [TestInitialize] | |||
| public void TestInit() | |||
| { | |||
| tf.enable_eager_execution(); | |||
| } | |||
| [TestCleanup] | |||
| public void TestClean() | |||
| { | |||
| } | |||
| } | |||
| } | |||