From ecbda0cdb98faae2b80eb4e4339898f708a01e2f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 28 Aug 2019 19:05:48 -0500 Subject: [PATCH] add RefVariable.read_value() --- src/TensorFlowNET.Core/APIs/tf.train.cs | 2 +- .../Train/ExponentialMovingAverage.cs | 22 +++++++++++++++--- .../Variables/RefVariable.cs | 23 +++++++++++++++++++ .../Variables/gen_state_ops.py.cs | 6 +++++ src/TensorFlowNET.Core/Variables/state_ops.cs | 8 +++++++ .../ImageProcessing/YOLO/Main.cs | 7 +++++- test/TensorFlowNET.UnitTest/ImageTest.cs | 3 +++ test/TensorFlowNET.UnitTest/NameScopeTest.cs | 4 +--- 8 files changed, 67 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 6d8a9935..a943308b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -31,7 +31,7 @@ namespace Tensorflow public Optimizer AdamOptimizer(float learning_rate, string name = "Adam") => new AdamOptimizer(learning_rate, name: name); - public object ExponentialMovingAverage(float decay) + public ExponentialMovingAverage ExponentialMovingAverage(float decay) => new ExponentialMovingAverage(decay); public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); diff --git a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs index 81e491ac..e129edce 100644 --- a/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs +++ b/src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Train { @@ -11,6 +13,7 @@ namespace Tensorflow.Train bool _zero_debias; string _name; public string name => _name; + List _averages; public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false, string name = "ExponentialMovingAverage") @@ -19,6 +22,7 @@ namespace Tensorflow.Train _num_updates = num_updates; _zero_debias = zero_debias; _name = name; + _averages = new List(); } /// @@ -26,11 +30,23 @@ namespace Tensorflow.Train /// /// /// - public Operation apply(VariableV1[] var_list = null) + public Operation apply(RefVariable[] var_list = null) { - throw new NotImplementedException(""); - } + if (var_list == null) + var_list = variables.trainable_variables() as RefVariable[]; + foreach(var var in var_list) + { + if (!_averages.Contains(var)) + { + ops.init_scope(); + var slot = new SlotCreator(); + var.initialized_value(); + // var avg = slot.create_zeros_slot + } + } + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index eb0f7316..e0e3e0f7 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -308,5 +308,28 @@ namespace Tensorflow { throw new NotImplementedException(); } + + /// + /// Returns the value of this variable, read in the current context. + /// + /// + private ITensorOrOperation read_value() + { + return array_ops.identity(_variable, name: "read"); + } + + public Tensor is_variable_initialized(RefVariable variable) + { + return state_ops.is_variable_initialized(variable); + } + + public Tensor initialized_value() + { + ops.init_scope(); + throw new NotImplementedException(""); + /*return control_flow_ops.cond(is_variable_initialized(this), + read_value, + () => initial_value);*/ + } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index af34a2ba..5c8744b6 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; using Tensorflow.Eager; @@ -145,5 +146,10 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking }); return _op.outputs[0]; } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 502c3c1e..8f478f2d 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -106,5 +106,13 @@ namespace Tensorflow throw new NotImplementedException("scatter_add"); } + + public static Tensor is_variable_initialized(RefVariable @ref, string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.is_variable_initialized(@ref: @ref, name: name); + throw new NotImplementedException(""); + //return @ref.is_initialized(name: name); + } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs index 57770866..935b9914 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs @@ -91,7 +91,12 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO tf_with(tf.name_scope("define_loss"), scope => { - model = new YOLOv3(cfg, input_data, trainable); + // model = new YOLOv3(cfg, input_data, trainable); + }); + + tf_with(tf.name_scope("define_weight_decay"), scope => + { + var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables()); }); return graph; diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index 4b6d5922..e4f8a835 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -8,6 +8,9 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest { + /// + /// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file + /// [TestClass] public class ImageTest { diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 3d763b38..5c307b9f 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -69,9 +69,7 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual("scope1", g._name_stack); var const3 = tf.constant(2.0); Assert.AreEqual("scope1/Const_1:0", const3.name); - } - - ; + }; g.Dispose();