Browse Source

Graph.unique_name fixed + test case

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
40af0c5f8a
4 changed files with 106 additions and 85 deletions
  1. +47
    -40
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +24
    -24
      test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
  3. +27
    -20
      test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
  4. +8
    -1
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 47
- 40
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -73,8 +73,8 @@ namespace Tensorflow
return var._as_graph_element(); return var._as_graph_element();


return null; return null;
}
}
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
{ {
string types_str = ""; string types_str = "";
@@ -99,7 +99,7 @@ namespace Tensorflow
// If obj appears to be a name... // If obj appears to be a name...
if (obj is string name) if (obj is string name)
{ {
if(name.Contains(":") && allow_tensor)
if (name.Contains(":") && allow_tensor)
{ {
string op_name = name.Split(':')[0]; string op_name = name.Split(':')[0];
int out_n = int.Parse(name.Split(':')[1]); int out_n = int.Parse(name.Split(':')[1]);
@@ -107,7 +107,7 @@ namespace Tensorflow
if (_nodes_by_name.ContainsKey(op_name)) if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n]; return _nodes_by_name[op_name].outputs[out_n];
} }
else if(!name.Contains(":") & allow_operation)
else if (!name.Contains(":") & allow_operation)
{ {
if (!_nodes_by_name.ContainsKey(name)) if (!_nodes_by_name.ContainsKey(name))
throw new KeyError($"The name {name} refers to an Operation not in the graph."); throw new KeyError($"The name {name} refers to an Operation not in the graph.");
@@ -166,8 +166,8 @@ namespace Tensorflow
throw new RuntimeError("Graph is finalized and cannot be modified."); throw new RuntimeError("Graph is finalized and cannot be modified.");
} }


public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null,
public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
{ {
if (inputs == null) if (inputs == null)
@@ -188,7 +188,7 @@ namespace Tensorflow
var input_ops = inputs.Select(x => x.op).ToArray(); var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops); var control_inputs = _control_dependencies_for_inputs(input_ops);


var op = new Operation(node_def,
var op = new Operation(node_def,
this, this,
inputs: inputs, inputs: inputs,
output_types: dtypes, output_types: dtypes,
@@ -259,54 +259,61 @@ namespace Tensorflow
_name_stack = new_stack; _name_stack = new_stack;


return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
}

}
/// <summary>
/// Return a unique operation name for `name`.
///
/// Note: You rarely need to call `unique_name()` directly.Most of
/// the time you just need to create `with g.name_scope()` blocks to
/// generate structured names.
///
/// `unique_name` is used to generate structured names, separated by
/// `"/"`, to help identify operations when debugging a graph.
/// Operation names are displayed in error messages reported by the
/// TensorFlow runtime, and in various visualization tools such as
/// TensorBoard.
///
/// If `mark_as_used` is set to `True`, which is the default, a new
/// unique name is created and marked as in use.If it's set to `False`,
/// the unique name is returned without actually being marked as used.
/// This is useful when the caller simply wants to know what the name
/// to be created will be.
/// </summary>
/// <param name="name">The name for an operation.</param>
/// <param name="mark_as_used"> Whether to mark this name as being used.</param>
/// <returns>A string to be passed to `create_op()` that will be used
/// to name the operation being created.</returns>
public string unique_name(string name, bool mark_as_used = true) public string unique_name(string name, bool mark_as_used = true)
{ {
if (!String.IsNullOrEmpty(_name_stack)) if (!String.IsNullOrEmpty(_name_stack))
{
name = _name_stack + "/" + name; name = _name_stack + "/" + name;
}

// For the sake of checking for names in use, we treat names as case
// insensitive (e.g. foo = Foo).
var name_key = name.ToLower(); var name_key = name.ToLower();
int i = 0; int i = 0;
if (_names_in_use.ContainsKey(name_key)) if (_names_in_use.ContainsKey(name_key))
{
foreach (var item in _names_in_use)
{
if (item.Key == name_key)
{
i = _names_in_use[name_key];
break;
}
i++;
}
}

i = _names_in_use[name_key];
// Increment the number for "name_key".
if (mark_as_used) if (mark_as_used)
if (_names_in_use.ContainsKey(name_key))
_names_in_use[name_key]++;
else
_names_in_use[name_key] = i + 1;
_names_in_use[name_key] = i + 1;
if (i > 0) if (i > 0)
{ {
var base_name_key = name_key;

// Make sure the composed name key is not already used. // Make sure the composed name key is not already used.
if (_names_in_use.ContainsKey(name_key))
var base_name_key = name_key;
while (_names_in_use.ContainsKey(name_key))
{ {
name_key = $"{base_name_key}_{i}"; name_key = $"{base_name_key}_{i}";
i += 1; i += 1;
} }

// Mark the composed name_key as used in case someone wants
// to call unique_name("name_1").
if (mark_as_used) if (mark_as_used)
_names_in_use[name_key] = 1; _names_in_use[name_key] = 1;


name = $"{name}_{i - 1}";
// Return the new name with the original capitalization of the given name.
name = $"{name}_{i-1}";
} }

return name; return name;
} }


@@ -375,8 +382,8 @@ namespace Tensorflow
public void prevent_fetching(Operation op) public void prevent_fetching(Operation op)
{ {
_unfetchable_ops.Add(op); _unfetchable_ops.Add(op);
}
}
public void Dispose() public void Dispose()
{ {
c_api.TF_DeleteGraph(_handle); c_api.TF_DeleteGraph(_handle);
@@ -387,8 +394,8 @@ namespace Tensorflow
} }


public void __exit__() public void __exit__()
{
{
} }


public static implicit operator IntPtr(Graph graph) public static implicit operator IntPtr(Graph graph)


+ 24
- 24
test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs View File

@@ -157,8 +157,8 @@ namespace TensorFlowNET.UnitTest
}); });
}); });
}); });
AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
} }
[TestMethod] [TestMethod]
@@ -200,12 +200,12 @@ namespace TensorFlowNET.UnitTest
b_none2 = constant_op.constant(12.0); b_none2 = constant_op.constant(12.0);
}); });
}); });
AssertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
AssertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
AssertItemsEqual(new object[0], b_none.op.control_inputs);
AssertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
AssertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
AssertItemsEqual(new object[0], b_none2.op.control_inputs);
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
assertItemsEqual(new object[0], b_none.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
assertItemsEqual(new object[0], b_none2.op.control_inputs);
} }
[TestMethod] [TestMethod]
@@ -256,25 +256,25 @@ namespace TensorFlowNET.UnitTest
}); });
}); });
AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
AssertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
AssertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
AssertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs);
assertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
AssertItemsEqual(new object[0], c_1.op.control_inputs);
AssertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
AssertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
AssertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
assertItemsEqual(new object[0], c_1.op.control_inputs);
assertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
assertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
assertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
AssertItemsEqual(new object[0], d_1.op.control_inputs);
AssertItemsEqual(new object[0], d_2.op.control_inputs);
AssertItemsEqual(new object[0], d_3.op.control_inputs);
AssertItemsEqual(new object[0], d_4.op.control_inputs);
assertItemsEqual(new object[0], d_1.op.control_inputs);
assertItemsEqual(new object[0], d_2.op.control_inputs);
assertItemsEqual(new object[0], d_3.op.control_inputs);
assertItemsEqual(new object[0], d_4.op.control_inputs);
AssertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
AssertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
AssertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
AssertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
assertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
assertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
assertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
assertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
} }
[Ignore("Don't know how to create an operation with two outputs")] [Ignore("Don't know how to create an operation with two outputs")]


+ 27
- 20
test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs View File

@@ -33,28 +33,35 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual("myop", op.name); Assert.AreEqual("myop", op.name);
Assert.AreEqual("Identity", op.type); Assert.AreEqual("Identity", op.type);
Assert.AreEqual(1, len(op.outputs)); Assert.AreEqual(1, len(op.outputs));
AssertItemsEqual(new []{2, 3}, op.outputs[0].shape);
assertItemsEqual(new []{2, 3}, op.outputs[0].shape);
}); });
} }
/*def testUniqueName(self):
g = ops.Graph()
with g.as_default():
c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
op = g._create_op_from_tf_operation(c_op)
op2 = g._create_op_from_tf_operation(c_op2)
# Create ops with same names as op1 and op2. We expect the new names to be
# uniquified.
op3 = test_ops.int_output(name="myop").op
op4 = test_ops.int_output(name="myop_1").op
self.assertEqual(op.name, "myop")
self.assertEqual(op2.name, "myop_1")
self.assertEqual(op3.name, "myop_2")
self.assertEqual(op4.name, "myop_1_1")
[TestMethod]
public void TestUniqueName()
{
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{
//var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
//var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
//var op = g._create_op_from_tf_operation(c_op);
//var op2 = g._create_op_from_tf_operation(c_op2);
var op = constant_op.constant(0, name:"myop").op;
var op2 = constant_op.constant(0, name: "myop_1").op;
// Create ops with same names as op1 and op2. We expect the new names to be
// uniquified.
var op3 = constant_op.constant(0, name: "myop").op;
var op4 = constant_op.constant(0, name: "myop_1").op;
self.assertEqual(op.name, "myop");
self.assertEqual(op2.name, "myop_1");
self.assertEqual(op3.name, "myop_2");
self.assertEqual(op4.name, "myop_1_1");
});
}
/*
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testCond(self): def testCond(self):
g = ops.Graph() g = ops.Graph()
@@ -164,5 +171,5 @@ namespace TensorFlowNET.UnitTest
*/ */
}
}
} }

+ 8
- 1
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
/// </summary> /// </summary>
public class PythonTest : Python public class PythonTest : Python
{ {
public void AssertItemsEqual(ICollection expected, ICollection given)
public void assertItemsEqual(ICollection expected, ICollection given)
{ {
Assert.IsNotNull(expected); Assert.IsNotNull(expected);
Assert.IsNotNull(given); Assert.IsNotNull(given);
@@ -23,5 +23,12 @@ namespace TensorFlowNET.UnitTest
for(int i=0; i<e.Length; i++) for(int i=0; i<e.Length; i++)
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
} }
public void assertEqual(object given, object expected)
{
Assert.AreEqual(expected, given);
}
protected PythonTest self { get => this; }
} }
} }

Loading…
Cancel
Save