From bcb803dfe7ce64e79306ea047113726afaa2ce91 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Thu, 10 Jan 2019 11:54:00 -0600 Subject: [PATCH] fix add_collections --- src/TensorFlowNET.Core/Graphs/Graph.cs | 17 ++++++++--- .../Operations/OpDefLibrary.cs | 11 ++++++-- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 16 +++++++++++ src/TensorFlowNET.Core/Train/Optimizer.cs | 28 ++++++++++++++++--- .../Variables/RefVariable.cs | 22 +++++++++++++++ .../Variables/gen_state_ops.py.cs | 2 ++ .../Variables/variable_scope.py.cs | 20 ++++++++++--- src/TensorFlowNET.Core/ops.name_scope.cs | 2 +- src/TensorFlowNET.Core/ops.py.cs | 12 ++++++-- .../LinearRegression.cs | 4 +-- 10 files changed, 114 insertions(+), 20 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 82516ee9..e0956b09 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -91,9 +91,18 @@ namespace Tensorflow throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); } - public void add_to_collection(string name, object value) + public void add_to_collection(string name, T value) { - _collections[name] = value; + if (_collections.ContainsKey(name)) + (_collections[name] as List).Add(value); + else + _collections[name] = new List { value }; + } + + public void add_to_collections(List names, T value) + { + foreach (string name in names) + add_to_collection(name, value); } public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes, @@ -236,9 +245,9 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } - public Dictionary get_collection(string name) + public object get_collection(string name) { - return _collections; + return _collections.ContainsKey(name) ? _collections[name] : null; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 121d4163..26bb6374 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -20,7 +20,7 @@ namespace Tensorflow name = op_type_name; } - string scope = g.unique_name(name) + "/"; + string scope = new ops.name_scope(name); var default_type_attr_map = new Dictionary(); foreach (var attr_def in op_def.Attr) @@ -88,6 +88,9 @@ namespace Tensorflow switch (attr_def.Type) { + case "string": + attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); + break; case "type": attr_value.Type = _MakeType((TF_DataType)value, attr_def); break; @@ -95,8 +98,12 @@ namespace Tensorflow attr_value.B = (bool)value; break; case "shape": - attr_value.Shape = new TensorShapeProto(); + attr_value.Shape = value == null ? + attr_def.DefaultValue.Shape : + tensor_util.as_shape((long[])value); break; + default: + throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); } attr_protos[key] = attr_value; diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index d1fe6b97..99a01b8b 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -73,6 +73,22 @@ namespace Tensorflow return nd; } + public static TensorShapeProto as_shape(long[] dims) + { + TensorShapeProto shape = new TensorShapeProto(); + + for (int i = 0; i < dims.Length; i++) + { + var dim = new TensorShapeProto.Types.Dim(); + dim.Size = dims[i]; + dim.Name = $"dim_{i}"; + + shape.Dim.Add(dim); + } + + return shape; + } + public static TensorShape as_shape(this IShape shape, int[] dims) { return new TensorShape(dims); diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index e006eec9..6b054d99 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -30,9 +30,14 @@ namespace Tensorflow /// /// /// - public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) + public Optimizer minimize(Tensor loss, + GateGradientType gate_gradients = GateGradientType.GATE_OP, + bool colocate_gradients_with_ops = false) { - compute_gradients(loss, gate_gradients); + compute_gradients(loss, + gate_gradients: gate_gradients, + colocate_gradients_with_ops: colocate_gradients_with_ops); + return this; } @@ -41,7 +46,10 @@ namespace Tensorflow /// /// /// - public List> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP) + public List> compute_gradients(Tensor loss, + List var_list = null, + GateGradientType gate_gradients = GateGradientType.GATE_OP, + bool colocate_gradients_with_ops = false) { int num_towers = 1; if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN) @@ -49,7 +57,19 @@ namespace Tensorflow } - var var_list = variables.trainable_variables(); + var tmp = variables.trainable_variables(); + switch (tmp) + { + case List values: + var_list = values; + break; + } + + foreach(var v in var_list) + { + + } + return null; } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index af861bbf..5b2171cf 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -64,6 +64,28 @@ namespace Tensorflow var shape = _initial_value.shape; dtype = _initial_value.dtype; _variable = gen_state_ops.variable_v2(shape, dtype, name); + + // Manually overrides the variable's shape with the initial value's. + if (validate_shape) + { + var initial_value_shape = _initial_value.shape; + } + + // If 'initial_value' makes use of other variables, make sure we don't + // have an issue if these other variables aren't initialized first by + // using their initialized_value() method. + + ops.add_to_collections(collections, this); + } + + public static implicit operator _VariableScopeStore(RefVariable variable) + { + return null; + } + + public static implicit operator RefVariable(_VariableScopeStore store) + { + return null; } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 8843c475..4ea475e2 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -23,6 +23,8 @@ namespace Tensorflow var keywords = new Dictionary(); keywords.Add("dtype", dtype); keywords.Add("shape", shape); + keywords.Add("container", container); + keywords.Add("shared_name", shared_name); var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords); diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 2c05581a..1c64d591 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -39,18 +39,30 @@ namespace Tensorflow public static _VariableScopeStore get_variable_scope_store() { + _VariableScopeStore ret = null; var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); if (scope_store == null) { - scope_store = new _VariableScopeStore(); - ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); + ret = new _VariableScopeStore(); + ops.add_to_collection(_VARSCOPESTORE_KEY, ret); } else { - // scope_store = scope_store[0]; + switch (scope_store) + { + case List values: + ret = values[0]; + break; + case List<_VariableScopeStore> values: + ret = values[0]; + break; + default: + throw new InvalidOperationException("get_variable_scope_store"); + } + } - return scope_store; + return ret; } public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index fef5c283..8d00c416 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -14,7 +14,7 @@ namespace Tensorflow public Context _ctx; public string _name_scope; - public name_scope(string name, string default_name, List values) + public name_scope(string name, string default_name = "", List values = null) { _name = name; _default_name = default_name; diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 4a1a6bcc..92312c8f 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -12,15 +12,21 @@ namespace Tensorflow { public partial class ops { - public static void add_to_collection(string name, object value) + public static void add_to_collection(string name, T value) { var graph = tf.get_default_graph(); graph.add_to_collection(name, value); } - public static _VariableScopeStore get_collection(string key) + public static void add_to_collections(List names, T value) { - return null;// get_default_graph().get_collection(key); + var graph = tf.get_default_graph(); + graph.add_to_collections(names, value); + } + + public static object get_collection(string key) + { + return get_default_graph().get_collection(key); } public static Graph get_default_graph() diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 2dc2d297..849db828 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -27,12 +27,12 @@ namespace TensorFlowNET.Examples var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3); var n_samples = train_X.shape[0]; - + // tf Graph Input var X = tf.placeholder(tf.float64); var Y = tf.placeholder(tf.float64); - // Set model weights + // Set model weights var W = tf.Variable(rng.randn(), name: "weight"); var b = tf.Variable(rng.randn(), name: "bias");