Browse Source

add RefVariable.read_value()

tags/v0.12
Oceania2018 6 years ago
parent
commit
ecbda0cdb9
8 changed files with 67 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +19
    -3
      src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
  3. +23
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  5. +8
    -0
      src/TensorFlowNET.Core/Variables/state_ops.cs
  6. +6
    -1
      test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs
  7. +3
    -0
      test/TensorFlowNET.UnitTest/ImageTest.cs
  8. +1
    -3
      test/TensorFlowNET.UnitTest/NameScopeTest.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow
public Optimizer AdamOptimizer(float learning_rate, string name = "Adam")
=> new AdamOptimizer(learning_rate, name: name);

public object ExponentialMovingAverage(float decay)
public ExponentialMovingAverage ExponentialMovingAverage(float decay)
=> new ExponentialMovingAverage(decay);

public Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);


+ 19
- 3
src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
@@ -11,6 +13,7 @@ namespace Tensorflow.Train
bool _zero_debias;
string _name;
public string name => _name;
List<VariableV1> _averages;

public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false,
string name = "ExponentialMovingAverage")
@@ -19,6 +22,7 @@ namespace Tensorflow.Train
_num_updates = num_updates;
_zero_debias = zero_debias;
_name = name;
_averages = new List<VariableV1>();
}

/// <summary>
@@ -26,11 +30,23 @@ namespace Tensorflow.Train
/// </summary>
/// <param name="var_list"></param>
/// <returns></returns>
public Operation apply(VariableV1[] var_list = null)
public Operation apply(RefVariable[] var_list = null)
{
throw new NotImplementedException("");
}
if (var_list == null)
var_list = variables.trainable_variables() as RefVariable[];

foreach(var var in var_list)
{
if (!_averages.Contains(var))
{
ops.init_scope();
var slot = new SlotCreator();
var.initialized_value();
// var avg = slot.create_zeros_slot
}
}

throw new NotImplementedException("");
}
}
}

+ 23
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -308,5 +308,28 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

/// <summary>
/// Returns the value of this variable, read in the current context.
/// </summary>
/// <returns></returns>
private ITensorOrOperation read_value()
{
return array_ops.identity(_variable, name: "read");
}

public Tensor is_variable_initialized(RefVariable variable)
{
return state_ops.is_variable_initialized(variable);
}

public Tensor initialized_value()
{
ops.init_scope();
throw new NotImplementedException("");
/*return control_flow_ops.cond(is_variable_initialized(this),
read_value,
() => initial_value);*/
}
}
}

+ 6
- 0
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
using System;
using System.Collections.Generic;
using Tensorflow.Eager;
@@ -145,5 +146,10 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("ScatterAdd", name: name, args: new { @ref, indices, updates, use_locking });
return _op.outputs[0];
}
public static Tensor is_variable_initialized(RefVariable @ref, string name = null)
{
throw new NotImplementedException("");
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -106,5 +106,13 @@ namespace Tensorflow
throw new NotImplementedException("scatter_add");
}
public static Tensor is_variable_initialized(RefVariable @ref, string name = null)
{
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.is_variable_initialized(@ref: @ref, name: name);
throw new NotImplementedException("");
//return @ref.is_initialized(name: name);
}
}
}

+ 6
- 1
test/TensorFlowNET.Examples/ImageProcessing/YOLO/Main.cs View File

@@ -91,7 +91,12 @@ namespace TensorFlowNET.Examples.ImageProcessing.YOLO

tf_with(tf.name_scope("define_loss"), scope =>
{
model = new YOLOv3(cfg, input_data, trainable);
// model = new YOLOv3(cfg, input_data, trainable);
});

tf_with(tf.name_scope("define_weight_decay"), scope =>
{
var moving_ave = tf.train.ExponentialMovingAverage(moving_ave_decay).apply((RefVariable[])tf.trainable_variables());
});

return graph;


+ 3
- 0
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -8,6 +8,9 @@ using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
/// <summary>
/// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file
/// </summary>
[TestClass]
public class ImageTest
{


+ 1
- 3
test/TensorFlowNET.UnitTest/NameScopeTest.cs View File

@@ -69,9 +69,7 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
}

;
};

g.Dispose();



Loading…
Cancel
Save