Browse Source

more cond test cases: testCondTrue and testCondFalse

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
c997c729a0
3 changed files with 78 additions and 58 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  2. +54
    -16
      test/TensorFlowNET.UnitTest/PythonTest.cs
  3. +18
    -42
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs

+ 6
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -38,6 +38,12 @@ namespace Tensorflow
Status.Check(true);
}

public Session as_default()
{
tf.defaultSession = this;
return this;
}

public static Session LoadFromSavedModel(string path)
{
var graph = c_api.TF_NewGraph();


+ 54
- 16
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -132,23 +132,66 @@ namespace TensorFlowNET.UnitTest
}
/// <summary>
/// Evaluates tensors and returns numpy values.
/// Evaluates tensors and returns a dictionary of {name:result, ...}.
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
/// </summary>
/// <returns> tensors numpy values.</returns>
public object evaluate(params Tensor[] tensors)
public Dictionary<string, NDArray> evaluate(params Tensor[] tensors)
{
var results = new Dictionary<string, NDArray>();
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == None)
with(self.session(), s => sess = s);
return sess.run(tensors);
if (sess == null)
sess = self.session();
with<Session>(sess, s =>
{
foreach (var t in tensors)
results[t.name] = t.eval();
});
return results;
}
}
public NDArray evaluate(Tensor tensor)
{
NDArray result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == null)
sess = self.session();
with<Session>(sess, s =>
{
result = tensor.eval();
});
return result;
}
}
public object eval_scalar(Tensor tensor)
{
NDArray result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == null)
sess = self.session();
with<Session>(sess, s =>
{
result = tensor.eval();
});
return result.Array.GetValue(0);
}
}
//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
@@ -188,16 +231,11 @@ namespace TensorFlowNET.UnitTest
//if (context.executing_eagerly())
// yield None
//else
{
with<Session>(self._create_session(graph, config, force_gpu), sess =>
{
with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
{
s = sess;
});
});
}
return s;
//{
s = self._create_session(graph, config, force_gpu);
self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
//}
return s.as_default();
}
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)


+ 18
- 42
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -13,53 +13,29 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestMethod]
public void testCondTrue()
{
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
self.assertEquals(self.evaluate(z), 34);
with(tf.Graph().as_default(), g =>
{
var x = tf.constant(2);
var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
self.assertEquals(eval_scalar(z), 34);
});
}
[Ignore("Todo")]
[Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod]
public void testCondFalse()
{
// def testCondFalse(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(
// x,
// y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
}
[Ignore("Todo")]
[TestMethod]
public void testCondTrueLegacy()
{
// def testCondTrueLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(5)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 34)
}
[Ignore("Todo")]
[TestMethod]
public void testCondFalseLegacy()
{
// def testCondFalseLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
with(tf.Graph().as_default(), g =>
{
var x = tf.constant(2);
var y = tf.constant(1);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
self.assertEquals(eval_scalar(z), 24);
});
}
[Ignore("Todo")]


Loading…
Cancel
Save