| @@ -23,14 +23,14 @@ namespace Tensorflow | |||
| { | |||
| public VariableV1[] global_variables(string scope = null) | |||
| { | |||
| return (ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)) | |||
| return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||
| .ToArray(); | |||
| } | |||
| public Operation global_variables_initializer() | |||
| { | |||
| var g = variables.global_variables(); | |||
| return variables.variables_initializer(g?.ToArray()); | |||
| return variables.variables_initializer(g.ToArray()); | |||
| } | |||
| /// <summary> | |||
| @@ -54,9 +54,9 @@ namespace Tensorflow | |||
| { | |||
| var scope = Tensorflow.variable_scope.get_variable_scope(); | |||
| var store = Tensorflow.variable_scope._get_default_variable_store(); | |||
| return scope.get_variable(store, | |||
| name, | |||
| shape: shape, | |||
| return scope.get_variable(store, | |||
| name, | |||
| shape: shape, | |||
| dtype: dtype, | |||
| use_resource: use_resource, | |||
| validate_shape: validate_shape, | |||
| @@ -10,11 +10,23 @@ namespace Tensorflow | |||
| { | |||
| public static class functools | |||
| { | |||
| public static Func<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg) | |||
| => (arg0) => func(arg0); | |||
| public static PartialFunc<Tin, Tout> partial<Tin, Tout>(Func<Tin, Tout> func, Tin arg) | |||
| => new PartialFunc<Tin, Tout> | |||
| { | |||
| args = arg, | |||
| invoke = func | |||
| }; | |||
| public static Func<Tin1, Tin2, Tout> partial<Tin1, Tin2, Tout>(Func<Tin1, Tin2, Tout> func, (Tin1, Tin2) args) | |||
| => (arg1, arg2) => func(arg1, arg2); | |||
| => (arg1, arg2) => func(args.Item1, args.Item2); | |||
| } | |||
| public class PartialFunc<Tin, Tout> | |||
| { | |||
| public Tin args { get; set; } | |||
| public object[] keywords { get; set; } | |||
| public Func<Tin, Tout> invoke { get; set; } | |||
| } | |||
| } | |||
| } | |||
| @@ -46,9 +46,9 @@ namespace Tensorflow | |||
| if (!string.IsNullOrEmpty(unbound_inputs_col_name)) | |||
| { | |||
| foreach(var col in meta_graph_def.CollectionDef) | |||
| foreach (var col in meta_graph_def.CollectionDef) | |||
| { | |||
| if(col.Key == unbound_inputs_col_name) | |||
| if (col.Key == unbound_inputs_col_name) | |||
| { | |||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||
| } | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||
| // Restores all the other collections. | |||
| var variable_objects = new Dictionary<ByteString, VariableV1>(); | |||
| foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) | |||
| foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) | |||
| { | |||
| // Don't add unbound_inputs to the new graph. | |||
| if (col.Key == unbound_inputs_col_name) | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||
| switch (col.Value.KindCase) | |||
| { | |||
| case KindOneofCase.NodeList: | |||
| foreach(var value in col.Value.NodeList.Value) | |||
| foreach (var value in col.Value.NodeList.Value) | |||
| { | |||
| var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); | |||
| graph.add_to_collection(col.Key, col_op); | |||
| @@ -115,7 +115,7 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| foreach(var value in col.Value.BytesList.Value) | |||
| foreach (var value in col.Value.BytesList.Value) | |||
| { | |||
| switch (col.Key) | |||
| { | |||
| @@ -139,7 +139,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||
| @@ -173,8 +173,8 @@ namespace Tensorflow | |||
| string unbound_inputs_col_name = "unbound_inputs", | |||
| bool clear_devices = false, | |||
| SaverDef saver_def = null, | |||
| bool clear_extraneous_savers= false, | |||
| bool strip_default_attrs= false, | |||
| bool clear_extraneous_savers = false, | |||
| bool strip_default_attrs = false, | |||
| byte[] meta_info_def = null) | |||
| { | |||
| var graph = ops.get_default_graph(); | |||
| @@ -236,12 +236,12 @@ namespace Tensorflow | |||
| meta_graph_def.GraphDef = graph_def; | |||
| // Fills in meta_info_def.stripped_op_list using the ops from graph_def. | |||
| if (meta_graph_def.MetaInfoDef.StrippedOpList == null || | |||
| if (meta_graph_def.MetaInfoDef.StrippedOpList == null || | |||
| meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) | |||
| meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); | |||
| var clist = graph.get_all_collection_keys(); | |||
| foreach(var ctype in clist) | |||
| foreach (var ctype in clist) | |||
| { | |||
| if (clear_extraneous_savers) | |||
| { | |||
| @@ -256,30 +256,34 @@ namespace Tensorflow | |||
| return meta_graph_def; | |||
| } | |||
| private static void add_collection_def(MetaGraphDef meta_graph_def, | |||
| string key, | |||
| private static void add_collection_def(MetaGraphDef meta_graph_def, | |||
| string key, | |||
| Graph graph = null, | |||
| string export_scope = "") | |||
| { | |||
| if (!meta_graph_def.CollectionDef.ContainsKey(key)) | |||
| meta_graph_def.CollectionDef[key] = new CollectionDef(); | |||
| var col_def = meta_graph_def.CollectionDef[key]; | |||
| col_def.NodeList = new Types.NodeList(); | |||
| col_def.BytesList = new Types.BytesList(); | |||
| foreach (object value in graph.get_collection(key)) | |||
| switch (graph.get_collection(key)) | |||
| { | |||
| switch (value) | |||
| { | |||
| case RefVariable x: | |||
| case List<RefVariable> collection_list: | |||
| col_def.BytesList = new Types.BytesList(); | |||
| foreach (var x in collection_list) | |||
| { | |||
| var proto = x.to_proto(export_scope); | |||
| col_def.BytesList.Value.Add(proto.ToByteString()); | |||
| break; | |||
| case ITensorOrOperation x2: | |||
| col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| break; | |||
| case List<object> collection_list: | |||
| col_def.NodeList = new Types.NodeList(); | |||
| foreach (var x in collection_list) | |||
| if (x is ITensorOrOperation x2) | |||
| col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); | |||
| break; | |||
| case List<Operation> collection_list: | |||
| break; | |||
| } | |||
| } | |||
| @@ -77,7 +77,7 @@ namespace Tensorflow | |||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
| public partial class Graph : DisposableObject | |||
| #if !SERIALIZABLE | |||
| ,IEnumerable<Operation> | |||
| , IEnumerable<Operation> | |||
| #endif | |||
| { | |||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
| @@ -100,15 +100,13 @@ namespace Tensorflow | |||
| /// </summary> | |||
| private bool _finalized = false; | |||
| /// <summary> | |||
| /// Arbitrary collections of objects inside the graph. | |||
| /// TODO: Access might be slow (-> O(n)) depending on size. | |||
| /// Arbitrary collections of objects. | |||
| /// </summary> | |||
| private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>(); | |||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | |||
| public bool building_function; | |||
| public bool building_function; | |||
| public Graph() | |||
| { | |||
| _handle = c_api.TF_NewGraph(); | |||
| @@ -230,14 +228,16 @@ namespace Tensorflow | |||
| throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); | |||
| } | |||
| public void add_to_collection(string name, object value) | |||
| public void add_to_collection<T>(string name, T value) | |||
| { | |||
| _check_not_finalized(); | |||
| _collections.Add((name, null, value)); | |||
| if (_collections.ContainsKey(name)) | |||
| (_collections[name] as List<T>).Add(value); | |||
| else | |||
| _collections[name] = new List<T> { value }; | |||
| } | |||
| public void add_to_collections(List<string> names, object value) | |||
| public void add_to_collections<T>(List<string> names, T value) | |||
| { | |||
| foreach (string name in names) | |||
| add_to_collection(name, value); | |||
| @@ -278,6 +278,12 @@ namespace Tensorflow | |||
| _create_op_helper(op, true); | |||
| /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | |||
| Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | |||
| Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); | |||
| Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); | |||
| Console.WriteLine();*/ | |||
| return op; | |||
| } | |||
| @@ -394,7 +400,7 @@ namespace Tensorflow | |||
| _names_in_use[name_key] = 1; | |||
| // Return the new name with the original capitalization of the given name. | |||
| name = $"{name}_{i-1}"; | |||
| name = $"{name}_{i - 1}"; | |||
| } | |||
| return name; | |||
| } | |||
| @@ -407,8 +413,8 @@ namespace Tensorflow | |||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
| unsafe | |||
| { | |||
| var tf_output_ptr = (TF_Output*) return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| var tf_output_ptr = (TF_Output*)return_output_handle; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| return return_outputs; | |||
| } | |||
| @@ -416,34 +422,46 @@ namespace Tensorflow | |||
| public string[] get_all_collection_keys() | |||
| { | |||
| return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray(); | |||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||
| } | |||
| public List<object> get_collection(string name, string scope = null) | |||
| public object get_collection(string name, string scope = null) | |||
| { | |||
| return get_collection<object>(name, scope); | |||
| } | |||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||
| } | |||
| public List<T> get_collection<T>(string name, string scope = null) | |||
| { | |||
| return (from c in _collections | |||
| where c.name == name && | |||
| (scope == null || c.scope == scope) && | |||
| implementationOf<T>(c.item) | |||
| select (T)(c.item)).ToList(); | |||
| } | |||
| private static bool implementationOf<T>(object item) | |||
| { | |||
| return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T))); | |||
| } | |||
| { | |||
| List<T> t = default; | |||
| var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>(); | |||
| switch (collection) | |||
| { | |||
| case List<VariableV1> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| case List<ResourceVariable> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| case List<RefVariable> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| case List<Tensor> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| case List<Operation> list: | |||
| t = list.Select(x => (T)(object)x).ToList(); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); | |||
| } | |||
| return t; | |||
| } | |||
| public List<T> get_collection_ref<T>(string name) | |||
| { | |||
| return get_collection<T>(name); | |||
| if (!_collections.ContainsKey(name)) | |||
| _collections[name] = new List<T>(); | |||
| return _collections[name] as List<T>; | |||
| } | |||
| public void prevent_feeding(Tensor tensor) | |||
| @@ -497,7 +515,7 @@ namespace Tensorflow | |||
| string debugString = string.Empty; | |||
| public override string ToString() | |||
| { | |||
| return $"{graph_key}, ({_handle})"; | |||
| return $"{graph_key}, ({_handle})"; | |||
| /*if (string.IsNullOrEmpty(debugString)) | |||
| { | |||
| int len = 0; | |||
| @@ -514,7 +532,7 @@ namespace Tensorflow | |||
| IEnumerator<Operation> IEnumerable<Operation>.GetEnumerator() | |||
| => GetEnumerable().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => throw new NotImplementedException(); | |||
| #endif | |||
| @@ -16,7 +16,7 @@ namespace Tensorflow.Train | |||
| // Create in proper graph and base name_scope. | |||
| var g = graph.as_default(); | |||
| g.name_scope(null); | |||
| var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, | |||
| var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, | |||
| initializer: tf.zeros_initializer, | |||
| trainable: false, | |||
| aggregation: VariableAggregation.OnlyFirstReplica, | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||
| protected Graph _graph; | |||
| bool _building_function; | |||
| public variable_scope(string name, | |||
| public variable_scope(string name, | |||
| string default_name = "", | |||
| Tensor[] values = null, | |||
| bool? reuse = null, | |||
| @@ -113,7 +113,7 @@ namespace Tensorflow | |||
| { | |||
| // Reenter the current name scope | |||
| string name_scope = ops.get_name_scope(); | |||
| if(!string.IsNullOrEmpty(name_scope)) | |||
| if (!string.IsNullOrEmpty(name_scope)) | |||
| // Hack to reenter | |||
| name_scope += "/"; | |||
| current_name_scope = ops.name_scope(name_scope); | |||
| @@ -128,8 +128,8 @@ namespace Tensorflow | |||
| string current_name_scope_name = current_name_scope; | |||
| _current_name_scope = current_name_scope; | |||
| string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope; | |||
| if(_scope == null) | |||
| if (_scope == null) | |||
| pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); | |||
| else | |||
| pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); | |||
| @@ -179,7 +179,7 @@ namespace Tensorflow | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int[] shape = null, | |||
| bool validate_shape = false, | |||
| bool ? use_resource = null, | |||
| bool? use_resource = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None) | |||
| { | |||
| @@ -189,7 +189,7 @@ namespace Tensorflow | |||
| use_resource = get_variable_scope().use_resource; | |||
| } | |||
| if(!use_resource.HasValue) | |||
| if (!use_resource.HasValue) | |||
| use_resource = _DEFAULT_USE_RESOURCE; | |||
| if (use_resource.Value) | |||
| @@ -204,7 +204,7 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| return new RefVariable(initial_value, | |||
| return new RefVariable(initial_value, | |||
| trainable: trainable.Value, | |||
| validate_shape: validate_shape, | |||
| collections: collections, | |||
| @@ -215,13 +215,13 @@ namespace Tensorflow | |||
| public static _VariableStore _get_default_variable_store() | |||
| { | |||
| var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault(); | |||
| if (store == null) | |||
| { | |||
| store = new _VariableStore(); | |||
| ops.add_to_collection(_VARSTORE_KEY, store); | |||
| } | |||
| return store; | |||
| var store = ops.get_collection(_VARSTORE_KEY); | |||
| if (store != null) | |||
| return (store as List<_VariableStore>)[0]; | |||
| var store1 = new _VariableStore(); | |||
| ops.add_to_collection(_VARSTORE_KEY, store1); | |||
| return store1; | |||
| } | |||
| public static VariableScope get_variable_scope() | |||
| @@ -231,15 +231,30 @@ namespace Tensorflow | |||
| public static _VariableScopeStore get_variable_scope_store() | |||
| { | |||
| var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault(); | |||
| if (scope_store == null) | |||
| scope_store = ops.get_collection<RefVariable>(_VARSCOPESTORE_KEY).FirstOrDefault(); | |||
| _VariableScopeStore ret = null; | |||
| var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | |||
| if (scope_store == null) | |||
| { | |||
| scope_store = new _VariableScopeStore(); | |||
| ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||
| ret = new _VariableScopeStore(); | |||
| ops.add_to_collection(_VARSCOPESTORE_KEY, ret); | |||
| } | |||
| else | |||
| { | |||
| switch (scope_store) | |||
| { | |||
| case List<RefVariable> values: | |||
| ret = values[0]; | |||
| break; | |||
| case List<_VariableScopeStore> values: | |||
| ret = values[0]; | |||
| break; | |||
| default: | |||
| throw new InvalidOperationException("get_variable_scope_store"); | |||
| } | |||
| } | |||
| return scope_store; | |||
| return ret; | |||
| } | |||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) | |||
| @@ -256,7 +271,7 @@ namespace Tensorflow | |||
| { | |||
| trainable = true; | |||
| } | |||
| return trainable.Value; | |||
| } | |||
| @@ -279,7 +294,7 @@ namespace Tensorflow | |||
| } | |||
| // TODO for Switch/Case | |||
| public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, | |||
| public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool trainable = false, | |||
| @@ -290,12 +305,12 @@ namespace Tensorflow | |||
| public void __init__() | |||
| { | |||
| } | |||
| public void __del__() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -63,7 +63,7 @@ namespace Tensorflow | |||
| /// list contains the values in the order under which they were | |||
| /// collected. | |||
| /// </returns> | |||
| public static List<object> get_collection(string key, string scope = null) | |||
| public static object get_collection(string key, string scope = null) | |||
| { | |||
| return get_default_graph().get_collection(key, scope); | |||
| } | |||