| @@ -0,0 +1,73 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public abstract class BackendBase | |||||
| { | |||||
| TF_DataType _FLOATX = dtypes.float32; | |||||
| float _EPSILON = 1e-7f; | |||||
| ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; | |||||
| public float epsilon() => _EPSILON; | |||||
| public void set_epsilon(float e) => _EPSILON = e; | |||||
| public TF_DataType floatx() => _FLOATX; | |||||
| public void set_floatx(TF_DataType floatx) => _FLOATX = floatx; | |||||
| public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype()); | |||||
| public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT; | |||||
| public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format; | |||||
| public ImageDataFormat normalize_data_format(object value = null) | |||||
| { | |||||
| if (value == null) | |||||
| value = _IMAGE_DATA_FORMAT; | |||||
| if (value.GetType() == typeof(ImageDataFormat)) | |||||
| return (ImageDataFormat)value; | |||||
| else if (value.GetType() == typeof(string)) | |||||
| { | |||||
| ImageDataFormat dataFormat; | |||||
| if(Enum.TryParse((string)value, true, out dataFormat)) | |||||
| { | |||||
| if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(",")) | |||||
| return dataFormat; | |||||
| } | |||||
| } | |||||
| throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString()); | |||||
| } | |||||
| //Legacy Methods | |||||
| public void set_image_dim_ordering(ImageDimOrder dim_ordering) | |||||
| { | |||||
| if (dim_ordering == ImageDimOrder.th) | |||||
| _IMAGE_DATA_FORMAT = ImageDataFormat.channels_first; | |||||
| else if (dim_ordering == ImageDimOrder.tf) | |||||
| _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; | |||||
| else | |||||
| throw new Exception("Unknown dim_ordering:"+ dim_ordering); | |||||
| } | |||||
| public ImageDimOrder image_dim_ordering() | |||||
| { | |||||
| if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first) | |||||
| return ImageDimOrder.th; | |||||
| else | |||||
| return ImageDimOrder.tf; | |||||
| } | |||||
| } | |||||
| public enum ImageDimOrder | |||||
| { | |||||
| tf, | |||||
| th | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,8 @@ | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public enum GraphLearningPhase | |||||
| { | |||||
| train_mode = 1, | |||||
| test_mode = 0 | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,8 @@ | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public enum ImageDataFormat | |||||
| { | |||||
| channels_last, | |||||
| channels_first | |||||
| } | |||||
| } | |||||
| @@ -95,14 +95,14 @@ namespace Tensorflow.Keras.Utils | |||||
| { | { | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| Dictionary<(string, string), int> name_uid_map = null; | Dictionary<(string, string), int> name_uid_map = null; | ||||
| if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key)) | |||||
| if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||||
| { | { | ||||
| name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key]; | |||||
| name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| name_uid_map = new Dictionary<(string, string), int>(); | name_uid_map = new Dictionary<(string, string), int>(); | ||||
| backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map; | |||||
| backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; | |||||
| } | } | ||||
| return name_uid_map; | return name_uid_map; | ||||
| @@ -1,31 +1,51 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Runtime.CompilerServices; | |||||
| using static Tensorflow.Python; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| public class backend | |||||
| public class backend : BackendBase | |||||
| { | { | ||||
| /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */ | |||||
| public static Func<Array, double> py_sum = sum; | |||||
| public static Func<Array, bool> py_all = all; | |||||
| //Func<Array, bool> py_any = any; | |||||
| //Func<double, double, double, IEnumerable<double>> py_slice = slice; | |||||
| public static Session _SESSION = Tensorflow.tf.defaultSession; | |||||
| public static Graph _GRAPH = null; | |||||
| public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES; | |||||
| //Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS; | |||||
| public static bool _MANUAL_VAR_INIT = false; | |||||
| public static List<string> _LOCAL_DEVICES = null; | |||||
| /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */ | |||||
| /// <summary> | /// <summary> | ||||
| /// A global dictionary mapping graph objects to an index of counters used | /// A global dictionary mapping graph objects to an index of counters used | ||||
| /// for various layer names in each graph. | /// for various layer names in each graph. | ||||
| /// Allows to give unique autogenerated names to layers, in a graph-specific way. | /// Allows to give unique autogenerated names to layers, in a graph-specific way. | ||||
| /// </summary> | /// </summary> | ||||
| public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>(); | |||||
| public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||||
| public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>(); | public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>(); | ||||
| public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | |||||
| public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); | |||||
| public static void track_variable(RefVariable v) | public static void track_variable(RefVariable v) | ||||
| { | { | ||||
| var graph = v.graph; | var graph = v.graph; | ||||
| _GRAPH_VARIABLES[graph.graph_key] = v; | _GRAPH_VARIABLES[graph.graph_key] = v; | ||||
| } | } | ||||
| public static Tensor placeholder(int[] shape = null, | |||||
| int ndim = -1, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| bool sparse = false, | |||||
| public static Tensor placeholder(int[] shape = null, | |||||
| int ndim = -1, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| bool sparse = false, | |||||
| string name = null) | string name = null) | ||||
| { | { | ||||
| if(sparse) | |||||
| if (sparse) | |||||
| { | { | ||||
| throw new NotImplementedException("placeholder sparse is true"); | throw new NotImplementedException("placeholder sparse is true"); | ||||
| } | } | ||||
| @@ -39,5 +59,56 @@ namespace Tensorflow.Keras | |||||
| { | { | ||||
| return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
| } | } | ||||
| public static int get_uid(string prefix, string @namespace = "") | |||||
| { | |||||
| var graph = tf.get_default_graph(); | |||||
| if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||||
| PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||||
| PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; | |||||
| return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; | |||||
| } | |||||
| public static int get_uid((string, string) name) | |||||
| { | |||||
| var graph = tf.get_default_graph(); | |||||
| if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||||
| PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||||
| PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; | |||||
| return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; | |||||
| } | |||||
| public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||||
| public static void clear_session() | |||||
| { | |||||
| ops.reset_default_graph(); | |||||
| reset_uids(); | |||||
| _SESSION = null; | |||||
| var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||||
| _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>(); | |||||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; | |||||
| } | |||||
| public static void manual_variable_initialization(bool value) | |||||
| { | |||||
| _MANUAL_VAR_INIT = value; | |||||
| } | |||||
| public static GraphLearningPhase learning_phase() | |||||
| { | |||||
| var graph = tf.get_default_graph(); | |||||
| if (_GRAPH_LEARNING_PHASES.ContainsKey(graph)) | |||||
| { | |||||
| var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase"); | |||||
| _GRAPH_LEARNING_PHASES[graph] = 0; | |||||
| } | |||||
| return _GRAPH_LEARNING_PHASES[graph]; | |||||
| } | |||||
| public static void set_learning_phase(bool value) | |||||
| { | |||||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||||
| } | |||||
| public class _DummyEagerGraph | |||||
| { } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,22 @@ | |||||
| using System.Collections.Generic; | |||||
| namespace System.Collections.Generic | |||||
| { | |||||
| public class defaultdict<TKey, TValue> : Dictionary<TKey, TValue> where TValue : new() | |||||
| { | |||||
| public new TValue this[TKey key] | |||||
| { | |||||
| get | |||||
| { | |||||
| TValue val; | |||||
| if(!TryGetValue(key, out val)) | |||||
| { | |||||
| val = default(TValue); | |||||
| Add(key, val); | |||||
| } | |||||
| return val; | |||||
| } | |||||
| set { base[key] = value; } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -184,6 +184,69 @@ namespace Tensorflow | |||||
| return dictionary; | return dictionary; | ||||
| } | } | ||||
| public static bool all(IEnumerable enumerable) | |||||
| { | |||||
| foreach (var e1 in enumerable) | |||||
| { | |||||
| if (!Convert.ToBoolean(e1)) | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| public static bool any(IEnumerable enumerable) | |||||
| { | |||||
| foreach (var e1 in enumerable) | |||||
| { | |||||
| if (Convert.ToBoolean(e1)) | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| public static double sum(IEnumerable enumerable) | |||||
| { | |||||
| var typedef = new Type[] { typeof(double), typeof(int), typeof(float) }; | |||||
| var sum = 0.0d; | |||||
| foreach (var e1 in enumerable) | |||||
| { | |||||
| if (!typedef.Contains(e1.GetType())) | |||||
| throw new Exception("Numeric array expected"); | |||||
| sum += (double)e1; | |||||
| } | |||||
| return sum; | |||||
| } | |||||
| public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | |||||
| { | |||||
| return sum(values.Keys); | |||||
| } | |||||
| public static IEnumerable<double> slice(double start, double end, double step = 1) | |||||
| { | |||||
| for (double i = start; i < end; i += step) | |||||
| yield return i; | |||||
| } | |||||
| public static IEnumerable<float> slice(float start, float end, float step = 1) | |||||
| { | |||||
| for (float i = start; i < end; i += step) | |||||
| yield return i; | |||||
| } | |||||
| public static IEnumerable<int> slice(int start, int end, int step = 1) | |||||
| { | |||||
| for (int i = start; i < end; i += step) | |||||
| yield return i; | |||||
| } | |||||
| public static IEnumerable<int> slice(int range) | |||||
| { | |||||
| for (int i = 0; i < range; i++) | |||||
| yield return i; | |||||
| } | |||||
| public static bool hasattr(object obj, string key) | public static bool hasattr(object obj, string key) | ||||
| { | { | ||||
| var __type__ = (obj).GetType(); | var __type__ = (obj).GetType(); | ||||
| @@ -0,0 +1,424 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Text; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue> | |||||
| { | |||||
| private Dictionary<WeakKey, TValue> _internalDictionary; | |||||
| private object _internalObject = new object(); | |||||
| private bool _finalized; | |||||
| public WeakKeyDictionary() | |||||
| { | |||||
| _internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer()); | |||||
| } | |||||
| public WeakKeyDictionary(int capacity) | |||||
| { | |||||
| _internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer()); | |||||
| } | |||||
| public WeakKeyDictionary(IEqualityComparer<TKey> comparer) | |||||
| { | |||||
| _internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer)); | |||||
| } | |||||
| public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer) | |||||
| { | |||||
| _internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer)); | |||||
| } | |||||
| // FXCop: this is not empty; we need to mark this so we know if a key | |||||
| // still has an active dictionary at its finalization. | |||||
| [SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] | |||||
| ~WeakKeyDictionary() | |||||
| { | |||||
| _finalized = true; | |||||
| } | |||||
| public ICollection<TKey> Keys | |||||
| { | |||||
| get | |||||
| { | |||||
| List<TKey> list = new List<TKey>(); | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (WeakKey key in _internalDictionary.Keys) | |||||
| { | |||||
| object TKey = key.Target; | |||||
| if (TKey != null) | |||||
| { | |||||
| list.Add((TKey)TKey); | |||||
| } | |||||
| } | |||||
| } | |||||
| return list; | |||||
| } | |||||
| } | |||||
| public ICollection<TValue> Values | |||||
| { | |||||
| get { | |||||
| lock (_internalObject) { | |||||
| return _internalDictionary.Values; | |||||
| } | |||||
| } | |||||
| } | |||||
| public int Count | |||||
| { | |||||
| get | |||||
| { | |||||
| // Ensure a fairly accurate count. | |||||
| ScavangeLostKeys(); | |||||
| lock (_internalObject) | |||||
| { | |||||
| return _internalDictionary.Count; | |||||
| } | |||||
| } | |||||
| } | |||||
| public bool IsReadOnly | |||||
| { | |||||
| get { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||||
| public TValue this[TKey key] | |||||
| { | |||||
| get { | |||||
| lock (_internalObject) { | |||||
| return _internalDictionary[new WeakKey(key)]; | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| WeakKey Tkey = new WeakKey(key); | |||||
| lock (_internalObject) | |||||
| { | |||||
| //_internalDictionary[Tkey] = value; | |||||
| _internalDictionary.Add(Tkey, value); | |||||
| } | |||||
| // This looks a bit weird but the purpose of the lost key finder is to execute | |||||
| // code in some future garbage collection phase so we immediately create some garbage. | |||||
| new LostKeyFinder(this, Tkey); | |||||
| } | |||||
| } | |||||
| public bool TryGetValue(TKey key, out TValue value) | |||||
| { | |||||
| WeakKey tkey = new WeakKey(key); | |||||
| lock (_internalObject) | |||||
| { | |||||
| return _internalDictionary.TryGetValue(tkey, out value); | |||||
| } | |||||
| } | |||||
| [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||||
| public void Add(TKey key, TValue value) | |||||
| { | |||||
| WeakKey tkey = new WeakKey(key); | |||||
| lock (_internalObject) | |||||
| { | |||||
| _internalDictionary.Add(tkey, value); | |||||
| } | |||||
| // This looks a bit weird but the purpose of the lost key finder is to execute | |||||
| // code in some future garbage collection phase so we immediately create some garbage. | |||||
| new LostKeyFinder(this, tkey); | |||||
| } | |||||
| public bool ContainsKey(TKey key) | |||||
| { | |||||
| return _internalDictionary.ContainsKey(new WeakKey(key)); | |||||
| } | |||||
| public bool Remove(TKey key) | |||||
| { | |||||
| lock (_internalObject) | |||||
| { | |||||
| return _internalDictionary.Remove(new WeakKey(key)); | |||||
| } | |||||
| } | |||||
| public void Add(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| Add(item.Key, item.Value); | |||||
| } | |||||
| public void Clear() | |||||
| { | |||||
| lock (_internalObject) | |||||
| { | |||||
| _internalDictionary.Clear(); | |||||
| } | |||||
| } | |||||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| TValue value; | |||||
| bool result; | |||||
| lock (_internalObject) | |||||
| { | |||||
| result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); | |||||
| } | |||||
| if (result) | |||||
| { | |||||
| return value.Equals(item.Value); | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
| { | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||||
| { | |||||
| KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value); | |||||
| array[arrayIndex] = kv; | |||||
| arrayIndex++; | |||||
| } | |||||
| } | |||||
| } | |||||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| WeakKey key = new WeakKey(item.Key); | |||||
| lock (_internalObject) | |||||
| { | |||||
| return _internalDictionary.Remove(key); | |||||
| } | |||||
| } | |||||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
| { | |||||
| List<WeakKey> lostKeys = null; | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||||
| { | |||||
| object TKey = item.Key.Target; | |||||
| if (TKey != null) | |||||
| { | |||||
| yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value); | |||||
| } | |||||
| else | |||||
| { | |||||
| if (lostKeys == null) | |||||
| { | |||||
| lostKeys = new List<WeakKey>(); | |||||
| } | |||||
| lostKeys.Add(item.Key); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Recover any lost keys. | |||||
| if (lostKeys != null) | |||||
| { | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (WeakKey key in lostKeys) | |||||
| { | |||||
| _internalDictionary.Remove(key); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| private void ScavangeLostKeys() | |||||
| { | |||||
| List<WeakKey> lostKeys = null; | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (WeakKey key in _internalDictionary.Keys) | |||||
| { | |||||
| if (!key.IsAlive) | |||||
| { | |||||
| if (lostKeys == null) | |||||
| { | |||||
| lostKeys = new List<WeakKey>(); | |||||
| } | |||||
| lostKeys.Add(key); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (lostKeys != null) | |||||
| { | |||||
| lock (_internalObject) | |||||
| { | |||||
| foreach (WeakKey key in lostKeys) | |||||
| { | |||||
| _internalDictionary.Remove(key); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() | |||||
| { | |||||
| return this.GetEnumerator(); | |||||
| } | |||||
| private class WeakKey : WeakReference | |||||
| { | |||||
| private int _hashCode; | |||||
| // private GCHandle _gcHandle; | |||||
| public WeakKey(TKey key) | |||||
| : base(key, true) | |||||
| { | |||||
| _hashCode = key.GetHashCode(); | |||||
| // Keep the key alive until it is explicitly collected | |||||
| // _gcHandle = GCHandle.Alloc(this); | |||||
| } | |||||
| internal void Release() | |||||
| { | |||||
| // _gcHandle.Free(); | |||||
| } | |||||
| public override int GetHashCode() | |||||
| { | |||||
| return _hashCode; | |||||
| } | |||||
| public override bool Equals(object obj) | |||||
| { | |||||
| if (obj == null) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (obj.GetHashCode() != _hashCode) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (obj != this && (!IsAlive || !obj.Equals(Target))) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } | |||||
| private class WeakComparer : IEqualityComparer<WeakKey> | |||||
| { | |||||
| private IEqualityComparer<TKey> _comparer; | |||||
| public WeakComparer() | |||||
| { | |||||
| } | |||||
| public WeakComparer(IEqualityComparer<TKey> comparer) | |||||
| { | |||||
| _comparer = comparer; | |||||
| } | |||||
| public bool Equals(WeakKey x, WeakKey y) | |||||
| { | |||||
| if (x.GetHashCode() != y.GetHashCode()) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (object.ReferenceEquals(x, y)) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| object ref1 = x.Target; | |||||
| if (ref1 == null) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| object ref2 = y.Target; | |||||
| if (ref2 == null) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if (_comparer != null) | |||||
| { | |||||
| return _comparer.Equals((TKey)ref1, (TKey)ref2); | |||||
| } | |||||
| else | |||||
| { | |||||
| return ref1.Equals(ref2); | |||||
| } | |||||
| } | |||||
| public int GetHashCode(WeakKey obj) | |||||
| { | |||||
| return obj.GetHashCode(); | |||||
| } | |||||
| } | |||||
| private class LostKeyFinder | |||||
| { | |||||
| WeakKeyDictionary<TKey, TValue> _dictionary; | |||||
| WeakKey _key; | |||||
| public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key) | |||||
| { | |||||
| _dictionary = dictionary; | |||||
| _key = key; | |||||
| } | |||||
| ~LostKeyFinder() | |||||
| { | |||||
| if (_dictionary._finalized || _key == null) | |||||
| { | |||||
| if (_key != null) | |||||
| { | |||||
| _key.Release(); | |||||
| _key = null; | |||||
| } | |||||
| return; | |||||
| } | |||||
| // if (!_key.IsAlive) { | |||||
| if (_key.Target == null) | |||||
| { | |||||
| lock (_dictionary._internalObject) | |||||
| { | |||||
| _dictionary._internalDictionary.Remove(_key); | |||||
| } | |||||
| _key.Release(); | |||||
| _key = null; | |||||
| } | |||||
| else if (_dictionary._internalDictionary.ContainsKey(_key)) | |||||
| { | |||||
| GC.ReRegisterForFinalize(this); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -10,6 +10,24 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class PythonBaseTests : PythonTest | public class PythonBaseTests : PythonTest | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void weakKeyDictionary_test() | |||||
| { | |||||
| var weakKeyDict = new WeakKeyDictionary<int, char>(); | |||||
| for (int i = 0; i < 5; i++) | |||||
| { | |||||
| var c = (char)((int)'a' + i); | |||||
| weakKeyDict[i] = c; | |||||
| //Assert.AreEqual(weakKeyDict.Count, (int)(i + 1)); | |||||
| var v = (weakKeyDict.Count == i + 1); | |||||
| Assert.IsTrue(v); | |||||
| } | |||||
| //Assert.AreEqual(weakKeyDict.Count, 0); | |||||
| var b = (weakKeyDict.Count == 0); | |||||
| Assert.IsTrue(b); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void hasattr_getattr() | public void hasattr_getattr() | ||||
| { | { | ||||