Browse Source

using and no IEnumerable

tags/v0.110.4-Transformer-Model
Alexander Novikov 2 years ago
parent
commit
9d71cad96e
3 changed files with 21 additions and 21 deletions
  1. +2
    -2
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs
  2. +9
    -7
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  3. +10
    -12
      test/TensorFlowNET.Graph.UnitTest/PythonTest.cs

+ 2
- 2
test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs View File

@@ -24,13 +24,13 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest
private void _testWhileContextHelper(int maximum_iterations) private void _testWhileContextHelper(int maximum_iterations)
{ {
// TODO: implement missing code dependencies // TODO: implement missing code dependencies
var sess = this.cached_session();
using var sess = this.cached_session();
var i = constant_op.constant(0, name: "i"); var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c")); var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c"));
var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c")); var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c"));
//control_flow_ops.while_loop( //control_flow_ops.while_loop(
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); // c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.Single().graph.get_operations())
foreach (Operation op in sess.graph.get_operations())
{ {
var control_flow_context = op._get_control_flow_context(); var control_flow_context = op._get_control_flow_context();
/*if (control_flow_context != null) /*if (control_flow_context != null)


+ 9
- 7
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -394,13 +394,15 @@ namespace TensorFlowNET.UnitTest.Gradient
// Test that we differentiate both 'x' and 'y' correctly when x is a // Test that we differentiate both 'x' and 'y' correctly when x is a
// predecessor of y. // predecessor of y.


var sess = self.cached_session().Single();
var x = tf.constant(1.0);
var y = x * 2.0;
var z = y * 3.0;
var grads = tf.gradients(z, new[] { x, y });
self.assertTrue(all(grads.Select(x => x != null)));
self.assertEqual(6.0, grads[0].eval());
using (self.cached_session())
{
var x = tf.constant(1.0);
var y = x * 2.0;
var z = y * 3.0;
var grads = tf.gradients(z, new[] { x, y });
self.assertTrue(all(grads.Select(x => x != null)));
self.assertEqual(6.0, grads[0].eval());
}
} }


[Ignore("TODO")] [Ignore("TODO")]


+ 10
- 12
test/TensorFlowNET.Graph.UnitTest/PythonTest.cs View File

@@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest
} }


///Returns a TensorFlow Session for use in executing tests. ///Returns a TensorFlow Session for use in executing tests.
public IEnumerable<Session> cached_session(
public Session cached_session(
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{ {
// This method behaves differently than self.session(): for performance reasons // This method behaves differently than self.session(): for performance reasons
@@ -267,9 +267,8 @@ namespace TensorFlowNET.UnitTest
{ {
var sess = self._get_cached_session( var sess = self._get_cached_session(
graph, config, force_gpu, crash_if_inconsistent_args: true); graph, config, force_gpu, crash_if_inconsistent_args: true);
var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
return cached;
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
return cached;
} }
} }


@@ -318,13 +317,12 @@ namespace TensorFlowNET.UnitTest
return s.as_default(); return s.as_default();
} }


private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
{ {
// Set the session and its graph to global default and constrain devices.""" // Set the session and its graph to global default and constrain devices."""
// if context.executing_eagerly():
// yield None
// else:
{
if (tf.executing_eagerly())
return null;
else {
sess.graph.as_default(); sess.graph.as_default();
sess.as_default(); sess.as_default();
{ {
@@ -340,13 +338,13 @@ namespace TensorFlowNET.UnitTest
using (sess.graph.device(gpu_name)) { using (sess.graph.device(gpu_name)) {
yield return sess; yield return sess;
}*/ }*/
yield return sess;
return sess;
} }
else if (use_gpu) else if (use_gpu)
yield return sess;
return sess;
else else
using (sess.graph.device("/device:CPU:0")) using (sess.graph.device("/device:CPU:0"))
yield return sess;
return sess;
} }
} }


Loading…
Cancel
Save