Browse Source

fix add_collections

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
bcb803dfe7
10 changed files with 114 additions and 20 deletions
  1. +13
    -4
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +9
    -2
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  4. +24
    -4
      src/TensorFlowNET.Core/Train/Optimizer.cs
  5. +22
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  7. +16
    -4
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  8. +1
    -1
      src/TensorFlowNET.Core/ops.name_scope.cs
  9. +9
    -3
      src/TensorFlowNET.Core/ops.py.cs
  10. +2
    -2
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 13
- 4
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -91,9 +91,18 @@ namespace Tensorflow
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
}

public void add_to_collection(string name, object value)
public void add_to_collection<T>(string name, T value)
{
_collections[name] = value;
if (_collections.ContainsKey(name))
(_collections[name] as List<T>).Add(value);
else
_collections[name] = new List<T> { value };
}

public void add_to_collections<T>(List<string> names, T value)
{
foreach (string name in names)
add_to_collection(name, value);
}

public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
@@ -236,9 +245,9 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}

public Dictionary<string, object> get_collection(string name)
public object get_collection(string name)
{
return _collections;
return _collections.ContainsKey(name) ? _collections[name] : null;
}

public void Dispose()


+ 9
- 2
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow
name = op_type_name;
}

string scope = g.unique_name(name) + "/";
string scope = new ops.name_scope(name);

var default_type_attr_map = new Dictionary<string, object>();
foreach (var attr_def in op_def.Attr)
@@ -88,6 +88,9 @@ namespace Tensorflow
switch (attr_def.Type)
{
case "string":
attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
break;
case "type":
attr_value.Type = _MakeType((TF_DataType)value, attr_def);
break;
@@ -95,8 +98,12 @@ namespace Tensorflow
attr_value.B = (bool)value;
break;
case "shape":
attr_value.Shape = new TensorShapeProto();
attr_value.Shape = value == null ?
attr_def.DefaultValue.Shape :
tensor_util.as_shape((long[])value);
break;
default:
throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
}

attr_protos[key] = attr_value;


+ 16
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -73,6 +73,22 @@ namespace Tensorflow
return nd;
}

public static TensorShapeProto as_shape(long[] dims)
{
TensorShapeProto shape = new TensorShapeProto();

for (int i = 0; i < dims.Length; i++)
{
var dim = new TensorShapeProto.Types.Dim();
dim.Size = dims[i];
dim.Name = $"dim_{i}";

shape.Dim.Add(dim);
}

return shape;
}

public static TensorShape as_shape(this IShape shape, int[] dims)
{
return new TensorShape(dims);


+ 24
- 4
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -30,9 +30,14 @@ namespace Tensorflow
/// </summary>
/// <param name="loss"></param>
/// <returns></returns>
public Optimizer minimize(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
public Optimizer minimize(Tensor loss,
GateGradientType gate_gradients = GateGradientType.GATE_OP,
bool colocate_gradients_with_ops = false)
{
compute_gradients(loss, gate_gradients);
compute_gradients(loss,
gate_gradients: gate_gradients,
colocate_gradients_with_ops: colocate_gradients_with_ops);

return this;
}

@@ -41,7 +46,10 @@ namespace Tensorflow
/// </summary>
/// <param name="loss"></param>
/// <param name="gate_gradients"></param>
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss, GateGradientType gate_gradients = GateGradientType.GATE_OP)
public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
List<RefVariable> var_list = null,
GateGradientType gate_gradients = GateGradientType.GATE_OP,
bool colocate_gradients_with_ops = false)
{
int num_towers = 1;
if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
@@ -49,7 +57,19 @@ namespace Tensorflow
}

var var_list = variables.trainable_variables();
var tmp = variables.trainable_variables();
switch (tmp)
{
case List<RefVariable> values:
var_list = values;
break;
}

foreach(var v in var_list)
{

}

return null;
}
}


+ 22
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -64,6 +64,28 @@ namespace Tensorflow
var shape = _initial_value.shape;
dtype = _initial_value.dtype;
_variable = gen_state_ops.variable_v2(shape, dtype, name);

// Manually overrides the variable's shape with the initial value's.
if (validate_shape)
{
var initial_value_shape = _initial_value.shape;
}

// If 'initial_value' makes use of other variables, make sure we don't
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.

ops.add_to_collections(collections, this);
}

public static implicit operator _VariableScopeStore(RefVariable variable)
{
return null;
}

public static implicit operator RefVariable(_VariableScopeStore store)
{
return null;
}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -23,6 +23,8 @@ namespace Tensorflow
var keywords = new Dictionary<string, object>();
keywords.Add("dtype", dtype);
keywords.Add("shape", shape);
keywords.Add("container", container);
keywords.Add("shared_name", shared_name);

var _op = _op_def_lib._apply_op_helper("VariableV2", name: name, keywords: keywords);



+ 16
- 4
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -39,18 +39,30 @@ namespace Tensorflow

public static _VariableScopeStore get_variable_scope_store()
{
_VariableScopeStore ret = null;
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
if (scope_store == null)
{
scope_store = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
ret = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, ret);
}
else
{
// scope_store = scope_store[0];
switch (scope_store)
{
case List<RefVariable> values:
ret = values[0];
break;
case List<_VariableScopeStore> values:
ret = values[0];
break;
default:
throw new InvalidOperationException("get_variable_scope_store");
}
}

return scope_store;
return ret;
}

public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)


+ 1
- 1
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow
public Context _ctx;
public string _name_scope;

public name_scope(string name, string default_name, List<object> values)
public name_scope(string name, string default_name = "", List<object> values = null)
{
_name = name;
_default_name = default_name;


+ 9
- 3
src/TensorFlowNET.Core/ops.py.cs View File

@@ -12,15 +12,21 @@ namespace Tensorflow
{
public partial class ops
{
public static void add_to_collection(string name, object value)
public static void add_to_collection<T>(string name, T value)
{
var graph = tf.get_default_graph();
graph.add_to_collection(name, value);
}

public static _VariableScopeStore get_collection(string key)
public static void add_to_collections<T>(List<string> names, T value)
{
return null;// get_default_graph().get_collection(key);
var graph = tf.get_default_graph();
graph.add_to_collections(names, value);
}

public static object get_collection(string key)
{
return get_default_graph().get_collection(key);
}

public static Graph get_default_graph()


+ 2
- 2
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -27,12 +27,12 @@ namespace TensorFlowNET.Examples
var train_Y = np.array(1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3);
var n_samples = train_X.shape[0];
// tf Graph Input
var X = tf.placeholder(tf.float64);
var Y = tf.placeholder(tf.float64);

// Set model weights
// Set model weights
var W = tf.Variable(rng.randn<double>(), name: "weight");
var b = tf.Variable(rng.randn<double>(), name: "bias");



Loading…
Cancel
Save