From e02d6d042d8a8302f5e350ca19da3eaa539702c9 Mon Sep 17 00:00:00 2001 From: Arnav Das Date: Sat, 8 Jun 2019 01:05:51 +0530 Subject: [PATCH 1/2] ongoing tf.keras.backend.cs --- src/TensorFlowNET.Core/Keras/BackendBase.cs | 73 ++++++++++++++++ .../Keras/GraphLearningPhase.cs | 8 ++ .../Keras/ImageDataFormat.cs | 8 ++ .../Keras/Utils/base_layer_utils.cs | 6 +- src/TensorFlowNET.Core/Keras/backend.cs | 85 +++++++++++++++++-- src/TensorFlowNET.Core/Keras/defaultdict.cs | 22 +++++ src/TensorFlowNET.Core/Python.cs | 63 ++++++++++++++ 7 files changed, 255 insertions(+), 10 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/BackendBase.cs create mode 100644 src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs create mode 100644 src/TensorFlowNET.Core/Keras/ImageDataFormat.cs create mode 100644 src/TensorFlowNET.Core/Keras/defaultdict.cs diff --git a/src/TensorFlowNET.Core/Keras/BackendBase.cs b/src/TensorFlowNET.Core/Keras/BackendBase.cs new file mode 100644 index 00000000..f3624485 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/BackendBase.cs @@ -0,0 +1,73 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace Tensorflow.Keras +{ + public abstract class BackendBase + { + TF_DataType _FLOATX = dtypes.float32; + float _EPSILON = 1e-7f; + ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; + + + public float epsilon() => _EPSILON; + + public void set_epsilon(float e) => _EPSILON = e; + + public TF_DataType floatx() => _FLOATX; + + public void set_floatx(TF_DataType floatx) => _FLOATX = floatx; + + public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype()); + + public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT; + + public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format; + + public ImageDataFormat normalize_data_format(object value = null) + { + if (value == null) + value = _IMAGE_DATA_FORMAT; + if (value.GetType() == typeof(ImageDataFormat)) + return (ImageDataFormat)value; + else if (value.GetType() == typeof(string)) + { + ImageDataFormat dataFormat; + if(Enum.TryParse((string)value, true, out dataFormat)) + { + if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(",")) + return dataFormat; + } + } + throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString()); + } + + //Legacy Methods + + public void set_image_dim_ordering(ImageDimOrder dim_ordering) + { + if (dim_ordering == ImageDimOrder.th) + _IMAGE_DATA_FORMAT = ImageDataFormat.channels_first; + else if (dim_ordering == ImageDimOrder.tf) + _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; + else + throw new Exception("Unknown dim_ordering:"+ dim_ordering); + } + + public ImageDimOrder image_dim_ordering() + { + if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first) + return ImageDimOrder.th; + else + return ImageDimOrder.tf; + } + } + public enum ImageDimOrder + { + tf, + th + } +} diff --git a/src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs b/src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs new file mode 100644 index 00000000..6f833e06 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras +{ + public enum GraphLearningPhase + { + train_mode = 1, + test_mode = 0 + } +} diff --git a/src/TensorFlowNET.Core/Keras/ImageDataFormat.cs b/src/TensorFlowNET.Core/Keras/ImageDataFormat.cs new file mode 100644 index 00000000..d32849fe --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ImageDataFormat.cs @@ -0,0 +1,8 @@ +namespace Tensorflow.Keras +{ + public enum ImageDataFormat + { + channels_last, + channels_first + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index 22c2cfc5..5244e8e9 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -95,14 +95,14 @@ namespace Tensorflow.Keras.Utils { var graph = ops.get_default_graph(); Dictionary<(string, string), int> name_uid_map = null; - if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key)) + if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) { - name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key]; + name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; } else { name_uid_map = new Dictionary<(string, string), int>(); - backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map; + backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; } return name_uid_map; diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs index 4213957f..e6c64f90 100644 --- a/src/TensorFlowNET.Core/Keras/backend.cs +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -1,31 +1,51 @@ using System; using System.Collections.Generic; using System.Text; +using System.Runtime.CompilerServices; +using static Tensorflow.Python; namespace Tensorflow.Keras { - public class backend + public class backend : BackendBase { + /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */ + public static Func py_sum = sum; + public static Func py_all = all; + //Func py_any = any; + //Func> py_slice = slice; + + public static Session _SESSION = Tensorflow.tf.defaultSession; + public static Graph _GRAPH = null; + public static Dictionary _GRAPH_LEARNING_PHASES; + //Dictionary> PER_GRAPH_LAYER_NAME_UIDS; + public static bool _MANUAL_VAR_INIT = false; + public static List _LOCAL_DEVICES = null; + /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */ + /// /// A global dictionary mapping graph objects to an index of counters used /// for various layer names in each graph. /// Allows to give unique autogenerated names to layers, in a graph-specific way. /// - public static Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public static Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); public static Dictionary _GRAPH_VARIABLES = new Dictionary(); + public static Dictionary _GRAPH_TF_OPTIMIZERS = new Dictionary(); + + public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); + public static void track_variable(RefVariable v) { var graph = v.graph; _GRAPH_VARIABLES[graph.graph_key] = v; } - public static Tensor placeholder(int[] shape = null, - int ndim = -1, - TF_DataType dtype = TF_DataType.DtInvalid, - bool sparse = false, + public static Tensor placeholder(int[] shape = null, + int ndim = -1, + TF_DataType dtype = TF_DataType.DtInvalid, + bool sparse = false, string name = null) { - if(sparse) + if (sparse) { throw new NotImplementedException("placeholder sparse is true"); } @@ -39,5 +59,56 @@ namespace Tensorflow.Keras { return ops.get_default_graph(); } + + public static int get_uid(string prefix, string @namespace = "") + { + var graph = tf.get_default_graph(); + if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) + PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); + PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; + + return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; + } + public static int get_uid((string, string) name) + { + var graph = tf.get_default_graph(); + if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) + PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); + PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; + + return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; + } + public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public static void clear_session() + { + ops.reset_default_graph(); + reset_uids(); + _SESSION = null; + var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); + _GRAPH_LEARNING_PHASES = new Dictionary(); + _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; + } + public static void manual_variable_initialization(bool value) + { + _MANUAL_VAR_INIT = value; + } + public static GraphLearningPhase learning_phase() + { + var graph = tf.get_default_graph(); + if (_GRAPH_LEARNING_PHASES.ContainsKey(graph)) + { + var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase"); + _GRAPH_LEARNING_PHASES[graph] = 0; + } + return _GRAPH_LEARNING_PHASES[graph]; + } + public static void set_learning_phase(bool value) + { + _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); + } + + + public class _DummyEagerGraph + { } } } diff --git a/src/TensorFlowNET.Core/Keras/defaultdict.cs b/src/TensorFlowNET.Core/Keras/defaultdict.cs new file mode 100644 index 00000000..849abd00 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/defaultdict.cs @@ -0,0 +1,22 @@ +using System.Collections.Generic; + +namespace System.Collections.Generic +{ + public class defaultdict : Dictionary where TValue : new() + { + public new TValue this[TKey key] + { + get + { + TValue val; + if(!TryGetValue(key, out val)) + { + val = default(TValue); + Add(key, val); + } + return val; + } + set { base[key] = value; } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs index 42d2063a..595fdd02 100644 --- a/src/TensorFlowNET.Core/Python.cs +++ b/src/TensorFlowNET.Core/Python.cs @@ -184,6 +184,69 @@ namespace Tensorflow return dictionary; } + + public static bool all(IEnumerable enumerable) + { + foreach (var e1 in enumerable) + { + if (!Convert.ToBoolean(e1)) + return false; + } + return true; + } + + public static bool any(IEnumerable enumerable) + { + foreach (var e1 in enumerable) + { + if (Convert.ToBoolean(e1)) + return true; + } + return false; + } + + public static double sum(IEnumerable enumerable) + { + var typedef = new Type[] { typeof(double), typeof(int), typeof(float) }; + var sum = 0.0d; + foreach (var e1 in enumerable) + { + if (!typedef.Contains(e1.GetType())) + throw new Exception("Numeric array expected"); + sum += (double)e1; + } + return sum; + } + + public static double sum(Dictionary values) + { + return sum(values.Keys); + } + + public static IEnumerable slice(double start, double end, double step = 1) + { + for (double i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(float start, float end, float step = 1) + { + for (float i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(int start, int end, int step = 1) + { + for (int i = start; i < end; i += step) + yield return i; + } + + public static IEnumerable slice(int range) + { + for (int i = 0; i < range; i++) + yield return i; + } + public static bool hasattr(object obj, string key) { var __type__ = (obj).GetType(); From bc9f57541d66d4b6f73fa6a5b46335da688743e5 Mon Sep 17 00:00:00 2001 From: Arnav Das Date: Sat, 8 Jun 2019 01:06:47 +0530 Subject: [PATCH 2/2] Ongoing WeakKeyDictionary --- src/TensorFlowNET.Core/WeakKeyDicionary.cs | 424 ++++++++++++++++++ .../TensorFlowNET.UnitTest/PythonBaseTests.cs | 18 + 2 files changed, 442 insertions(+) create mode 100644 src/TensorFlowNET.Core/WeakKeyDicionary.cs diff --git a/src/TensorFlowNET.Core/WeakKeyDicionary.cs b/src/TensorFlowNET.Core/WeakKeyDicionary.cs new file mode 100644 index 00000000..98df4f30 --- /dev/null +++ b/src/TensorFlowNET.Core/WeakKeyDicionary.cs @@ -0,0 +1,424 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public class WeakKeyDictionary : IDictionary + { + + private Dictionary _internalDictionary; + private object _internalObject = new object(); + private bool _finalized; + + public WeakKeyDictionary() + { + _internalDictionary = new Dictionary(new WeakComparer()); + } + + public WeakKeyDictionary(int capacity) + { + _internalDictionary = new Dictionary(capacity, new WeakComparer()); + } + + public WeakKeyDictionary(IEqualityComparer comparer) + { + _internalDictionary = new Dictionary(new WeakComparer(comparer)); + } + + public WeakKeyDictionary(int capacity, IEqualityComparer comparer) + { + _internalDictionary = new Dictionary(capacity, new WeakComparer(comparer)); + } + + // FXCop: this is not empty; we need to mark this so we know if a key + // still has an active dictionary at its finalization. + [SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] + ~WeakKeyDictionary() + { + _finalized = true; + } + + public ICollection Keys + { + get + { + List list = new List(); + lock (_internalObject) + { + foreach (WeakKey key in _internalDictionary.Keys) + { + object TKey = key.Target; + if (TKey != null) + { + list.Add((TKey)TKey); + } + } + } + return list; + } + } + + public ICollection Values + { + get { + lock (_internalObject) { + return _internalDictionary.Values; + } + } + } + + public int Count + { + get + { + // Ensure a fairly accurate count. + ScavangeLostKeys(); + lock (_internalObject) + { + return _internalDictionary.Count; + } + } + } + + public bool IsReadOnly + { + get { + return false; + } + } + + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] + public TValue this[TKey key] + { + get { + lock (_internalObject) { + return _internalDictionary[new WeakKey(key)]; + } + } + set + { + WeakKey Tkey = new WeakKey(key); + lock (_internalObject) + { + //_internalDictionary[Tkey] = value; + _internalDictionary.Add(Tkey, value); + } + // This looks a bit weird but the purpose of the lost key finder is to execute + // code in some future garbage collection phase so we immediately create some garbage. + new LostKeyFinder(this, Tkey); + } + } + + + + + + public bool TryGetValue(TKey key, out TValue value) + { + WeakKey tkey = new WeakKey(key); + lock (_internalObject) + { + return _internalDictionary.TryGetValue(tkey, out value); + } + } + + + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] + public void Add(TKey key, TValue value) + { + WeakKey tkey = new WeakKey(key); + lock (_internalObject) + { + _internalDictionary.Add(tkey, value); + } + // This looks a bit weird but the purpose of the lost key finder is to execute + // code in some future garbage collection phase so we immediately create some garbage. + new LostKeyFinder(this, tkey); + + } + + public bool ContainsKey(TKey key) + { + return _internalDictionary.ContainsKey(new WeakKey(key)); + } + + public bool Remove(TKey key) + { + lock (_internalObject) + { + return _internalDictionary.Remove(new WeakKey(key)); + } + } + + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + public void Clear() + { + lock (_internalObject) + { + _internalDictionary.Clear(); + } + } + + public bool Contains(KeyValuePair item) + { + TValue value; + bool result; + lock (_internalObject) + { + result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); + } + if (result) + { + return value.Equals(item.Value); + } + else + { + return false; + } + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + lock (_internalObject) + { + foreach (KeyValuePair item in _internalDictionary) + { + KeyValuePair kv = new KeyValuePair((TKey)item.Key.Target, item.Value); + array[arrayIndex] = kv; + arrayIndex++; + } + } + } + + public bool Remove(KeyValuePair item) + { + WeakKey key = new WeakKey(item.Key); + lock (_internalObject) + { + return _internalDictionary.Remove(key); + } + } + + + + + + public IEnumerator> GetEnumerator() + { + List lostKeys = null; + lock (_internalObject) + { + foreach (KeyValuePair item in _internalDictionary) + { + object TKey = item.Key.Target; + if (TKey != null) + { + yield return new KeyValuePair((TKey)TKey, item.Value); + } + else + { + if (lostKeys == null) + { + lostKeys = new List(); + } + lostKeys.Add(item.Key); + } + } + } + // Recover any lost keys. + if (lostKeys != null) + { + lock (_internalObject) + { + foreach (WeakKey key in lostKeys) + { + _internalDictionary.Remove(key); + } + } + } + } + + + + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + + + private void ScavangeLostKeys() + { + List lostKeys = null; + lock (_internalObject) + { + foreach (WeakKey key in _internalDictionary.Keys) + { + if (!key.IsAlive) + { + if (lostKeys == null) + { + lostKeys = new List(); + } + lostKeys.Add(key); + } + } + } + if (lostKeys != null) + { + lock (_internalObject) + { + foreach (WeakKey key in lostKeys) + { + _internalDictionary.Remove(key); + } + } + } + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return this.GetEnumerator(); + } + + private class WeakKey : WeakReference + { + private int _hashCode; + // private GCHandle _gcHandle; + + public WeakKey(TKey key) + : base(key, true) + { + _hashCode = key.GetHashCode(); + // Keep the key alive until it is explicitly collected + // _gcHandle = GCHandle.Alloc(this); + } + + internal void Release() + { + // _gcHandle.Free(); + } + + public override int GetHashCode() + { + return _hashCode; + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + if (obj.GetHashCode() != _hashCode) + { + return false; + } + if (obj != this && (!IsAlive || !obj.Equals(Target))) + { + return false; + } + return true; + } + } + + private class WeakComparer : IEqualityComparer + { + + private IEqualityComparer _comparer; + public WeakComparer() + { + } + + public WeakComparer(IEqualityComparer comparer) + { + _comparer = comparer; + } + + public bool Equals(WeakKey x, WeakKey y) + { + if (x.GetHashCode() != y.GetHashCode()) + { + return false; + } + if (object.ReferenceEquals(x, y)) + { + return true; + } + object ref1 = x.Target; + if (ref1 == null) + { + return false; + } + object ref2 = y.Target; + if (ref2 == null) + { + return false; + } + + if (_comparer != null) + { + return _comparer.Equals((TKey)ref1, (TKey)ref2); + } + else + { + return ref1.Equals(ref2); + } + } + + public int GetHashCode(WeakKey obj) + { + return obj.GetHashCode(); + } + } + + private class LostKeyFinder + { + WeakKeyDictionary _dictionary; + WeakKey _key; + + public LostKeyFinder(WeakKeyDictionary dictionary, WeakKey key) + { + _dictionary = dictionary; + _key = key; + } + + ~LostKeyFinder() + { + if (_dictionary._finalized || _key == null) + { + if (_key != null) + { + _key.Release(); + _key = null; + } + return; + } + // if (!_key.IsAlive) { + if (_key.Target == null) + { + lock (_dictionary._internalObject) + { + _dictionary._internalDictionary.Remove(_key); + } + _key.Release(); + _key = null; + } + else if (_dictionary._internalDictionary.ContainsKey(_key)) + { + GC.ReRegisterForFinalize(this); + } + } + } + } +} + \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs index 765a71c2..c5010923 100644 --- a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs +++ b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs @@ -10,6 +10,24 @@ namespace TensorFlowNET.UnitTest [TestClass] public class PythonBaseTests : PythonTest { + [Ignore] + [TestMethod] + public void weakKeyDictionary_test() + { + var weakKeyDict = new WeakKeyDictionary(); + for (int i = 0; i < 5; i++) + { + var c = (char)((int)'a' + i); + weakKeyDict[i] = c; + //Assert.AreEqual(weakKeyDict.Count, (int)(i + 1)); + var v = (weakKeyDict.Count == i + 1); + Assert.IsTrue(v); + } + //Assert.AreEqual(weakKeyDict.Count, 0); + var b = (weakKeyDict.Count == 0); + Assert.IsTrue(b); + } + [TestMethod] public void hasattr_getattr() {