diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index bc80ff72..efa79e00 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -101,8 +101,8 @@ namespace Tensorflow.Keras.Layers throw new NotImplementedException("call channels_first"); } else - { - outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); + { + outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC"); } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 4f3880d1..07dab399 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -10,22 +10,22 @@ using System.Text; namespace Tensorflow { - /// - /// Represents a graph node that performs computation on tensors. - /// - /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or - /// more `Tensor` objects as input, and produces zero or more `Tensor` - /// objects as output. Objects of type `Operation` are created by - /// calling an op constructor(such as `tf.matmul`) - /// or `tf.Graph.create_op`. - /// - /// For example `c = tf.matmul(a, b)` creates an `Operation` of type - /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` - /// as output. - /// - /// After the graph has been launched in a session, an `Operation` can - /// be executed by passing it to - /// `tf.Session.run`. + /// + /// Represents a graph node that performs computation on tensors. + /// + /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or + /// more `Tensor` objects as input, and produces zero or more `Tensor` + /// objects as output. Objects of type `Operation` are created by + /// calling an op constructor(such as `tf.matmul`) + /// or `tf.Graph.create_op`. + /// + /// For example `c = tf.matmul(a, b)` creates an `Operation` of type + /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` + /// as output. + /// + /// After the graph has been launched in a session, an `Operation` can + /// be executed by passing it to + /// `tf.Session.run`. /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. /// public partial class Operation : ITensorOrOperation @@ -271,47 +271,49 @@ namespace Tensorflow return base.Equals(obj); } - /// - /// Update the input to this operation at the given index. - /// - /// NOTE: This is for TF internal use only.Please don't use it. - /// - /// the index of the input to update. - /// the Tensor to be used as the input at the given index. - public void _update_input(int index, Tensor tensor) - { - _assert_same_graph(tensor); - - var input = _tf_input(index); + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. + /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + _assert_same_graph(tensor); + + var input = _tf_input(index); var output = tensor._as_tf_output(); // Reset cached inputs. - _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? 
original code: self._inputs_val=None - // TODO: implement below code dependencies - c_api.TF_UpdateEdge(graph, output, input, status); - } - - private void _assert_same_graph(Tensor tensor) - { - //TODO: implement - } - - /// - /// Create and return a new TF_Output for output_idx'th output of this op. - /// - public TF_Output _tf_output(int output_idx) - { - var tf_output = new TF_Output(op, output_idx); - return tf_output; - } - - /// - /// Create and return a new TF_Input for input_idx'th input of this op. - /// - public TF_Input _tf_input(int input_idx) - { - var tf_input = new TF_Input(op, input_idx); - return tf_input; - } - } -} + _inputs = null; + // after the c_api call next time _inputs is accessed + // the updated inputs are reloaded from the c_api + c_api.TF_UpdateEdge(_graph, output, input, status); + //var updated_inputs = inputs; + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } + + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + var tf_output = new TF_Output(op, output_idx); + return tf_output; + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + var tf_input = new TF_Input(op, input_idx); + return tf_input; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index fad7a1e1..dbb7a96e 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -290,33 +290,11 @@ namespace Tensorflow { // TODO: here a chunk of original code is missing /* - if fn1 is not None: - if true_fn is not None: - raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") - true_fn = fn1 - elif true_fn is None: - raise TypeError("cond(): true_fn argument required") - if fn2 is not None: - if false_fn is not None: - raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") - false_fn = fn2 - elif false_fn is None: - raise TypeError("cond(): false_fn argument required") - - if not callable(true_fn): - raise TypeError("true_fn must be callable.") - if not callable(false_fn): - raise TypeError("false_fn must be callable.") - with ops.name_scope(name, "cond", [pred]): if context.executing_eagerly(): if pred: return _UnpackIfSingleton(true_fn()) return _UnpackIfSingleton(false_fn()) - - # Add the Switch to the graph. - if isinstance(pred, bool): - raise TypeError("pred must not be a Python bool") */ // Add the Switch to the graph. 
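Reviewer note on the Operation.cs hunk above: the XML doc describes the graph/session contract (an op constructor such as `tf.matmul` builds a graph node, and `Session.run` executes it), and `_update_input` now only clears the cached `_inputs` so they are lazily reloaded through the C API after `TF_UpdateEdge`. A minimal sketch of the documented usage, assuming the `tf.constant`/`tf.matmul`/`tf.Session` surface of TensorFlowNET.Core (names not verified against this revision; illustration only, not part of the patch):

    // builds a "MatMul" Operation whose output tensor is c
    var a = tf.constant(new[,] { { 1.0f, 2.0f }, { 3.0f, 4.0f } });
    var b = tf.constant(new[,] { { 5.0f, 6.0f }, { 7.0f, 8.0f } });
    var c = tf.matmul(a, b);
    // launching the graph in a session and passing c to run() executes the op
    var sess = tf.Session();
    var result = sess.run(c);
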
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index e54caf66..a5a2815c 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -30,7 +30,7 @@ namespace Tensorflow /// /// public static Tensor bias_add(Tensor value, - RefVariable bias, + Tensor bias, string data_format = null, string name = null) { diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index e43dd913..a610f4e7 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -38,6 +38,12 @@ namespace Tensorflow Status.Check(true); } + public Session as_default() + { + tf.defaultSession = this; + return this; + } + public static Session LoadFromSavedModel(string path) { var graph = c_api.TF_NewGraph(); diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 97d49932..b888f883 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -132,24 +132,44 @@ namespace TensorFlowNET.UnitTest } /// - /// Evaluates tensors and returns numpy values. - /// A Tensor or a nested list/tuple of Tensors. + /// This function is used in many original tensorflow unit tests to evaluate tensors + /// in a test session with special settings (for instance constant folding off) + /// /// - /// tensors numpy values. - [Obsolete("Why do we need this function? we already have Tensor.eval().")] - public object evaluate(params Tensor[] tensors) + public T evaluate(Tensor tensor) { + var results = new Dictionary(); // if context.executing_eagerly(): // return self._eval_helper(tensors) // else: { var sess = ops.get_default_session(); - if (sess == None) - with(self.session(), s => sess = s); - return sess.run(tensors); + if (sess == null) + sess = self.session(); + T t_result = (T)(object)null; + with(sess, s => + { + var ndarray=tensor.eval(); + if (typeof(T) == typeof(double)) + { + double d = ndarray; + t_result = (T)(object)d; + } + else if (typeof(T) == typeof(int)) + { + int d = ndarray; + t_result = (T) (object) d; + } + else + { + t_result = (T)(object)ndarray; + } + }); + return t_result; } } + //Returns a TensorFlow Session for use in executing tests. 
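+        // The session returned below is registered via as_default() so that ops.get_default_session()
+        // (used by the evaluate<T> helper above and by Tensor.eval()) can find it.
+        // Hypothetical usage sketch (tf.multiply is illustrative only; any int32-producing op works):
+        //   var z = tf.multiply(tf.constant(2), tf.constant(5));
+        //   self.assertEquals(10, self.evaluate<int>(z));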
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) { @@ -189,16 +209,11 @@ namespace TensorFlowNET.UnitTest //if (context.executing_eagerly()) // yield None //else - { - with(self._create_session(graph, config, force_gpu), sess => - { - with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) => - { - s = sess; - }); - }); - } - return s; + //{ + s = self._create_session(graph, config, force_gpu); + self._constrain_devices_and_set_default(s, use_gpu, force_gpu); + //} + return s.as_default(); } private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu) diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 01a774a5..dd364149 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -91,74 +91,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test }); } - [Ignore("Todo")] - [TestMethod] - public void testCondTrueLegacy() - { - // def testCondTrueLegacy(self): - // x = constant_op.constant(2) - // y = constant_op.constant(5) - // z = control_flow_ops.cond( - // math_ops.less(x, y), - // fn1=lambda: math_ops.multiply(x, 17), - // fn2=lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 34) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondFalseLegacy() - { - // def testCondFalseLegacy(self): - // x = constant_op.constant(2) - // y = constant_op.constant(1) - // z = control_flow_ops.cond( - // math_ops.less(x, y), - // fn1=lambda: math_ops.multiply(x, 17), - // fn2=lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 24) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondMissingArg1() - { - // def testCondMissingArg1(self): - // x = constant_op.constant(1) - // with self.assertRaises(TypeError): - // control_flow_ops.cond(True, false_fn=lambda: x) - - } - - [Ignore("Todo")] - [TestMethod] - public void testCondMissingArg2() - { - // def testCondMissingArg2(self): - // x = constant_op.constant(1) - // with self.assertRaises(TypeError): - // control_flow_ops.cond(True, lambda: x) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondDuplicateArg1() - { - // def testCondDuplicateArg1(self): - // x = constant_op.constant(1) - // with self.assertRaises(TypeError): - // control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondDuplicateArg2() - { - // def testCondDuplicateArg2(self): - // x = constant_op.constant(1) - // with self.assertRaises(TypeError): - // control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) - } + // NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api } } diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs new file mode 100644 index 00000000..2d308042 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs @@ -0,0 +1,507 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; + +namespace TensorFlowNET.UnitTest.gradients_test +{ + [TestClass] + public class GradientsTest : PythonTest + { + + //[Ignore("TODO")] + [TestMethod] + public void testGradients() + { + 
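+            // Port of gradients_test.py testGradients (reference file added below in this PR):
+            // for h = matmul(inp, w) + b the gradient w.r.t. w is itself a MatMul
+            // (dw = inp^T * dh), so the gradient op must be of type "MatMul" with
+            // transpose_a == true and transpose_b == false, which is what is asserted here.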
with(tf.Graph().as_default(), g => + { + var inp = tf.constant(1.0, shape: new[]{32, 100}, name:"in"); + var w = tf.constant(1.0, shape: new[] { 100, 10}, name:"w"); + var b = tf.constant(1.0, shape: new[] { 10}, name:"b"); + var xw = math_ops.matmul(inp, w, name: "xw"); + var h = nn_ops.bias_add(xw, b, name: "h"); + var w_grad = gradients_impl.gradients(new []{h}, new[] { w})[0]; + self.assertEquals("MatMul", w_grad.op.type); + // TODO: Operation._original_op + //self.assertEquals(w_grad.op._original_op, xw.op); + self.assertTrue((bool)w_grad.op.get_attr("transpose_a")); + self.assertFalse((bool)w_grad.op.get_attr("transpose_b")); + }); + + } + + [Ignore("TODO")] + [TestMethod] + public void testUnusedOutput() + { + //def testUnusedOutput(self): + // with ops.Graph().as_default(): + // w = constant(1.0, shape=[2, 2]) + // x = constant(1.0, shape=[2, 2]) + // wx = math_ops.matmul(w, x) + // split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0) + // c = math_ops.reduce_sum(split_wx[1]) + // gw = gradients.gradients(c, [w])[0] + // self.assertEquals("MatMul", gw.op.type) + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradients() + { + + //def testColocateGradients(self): + // with ops.Graph().as_default() as g: + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // with g.device("/device:GPU:0"): + // wx = math_ops.matmul(w, x) + // gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups()) + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithAggregation() + { + //def testColocateGradientsWithAggregation(self): + // with ops.Graph().as_default() as g: + // with g.device("/device:GPU:1"): + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // y = constant(1.0, shape=[1, 2]) + // wx = math_ops.matmul(w, x) + // wy = math_ops.matmul(w, y) + // with g.device("/device:GPU:0"): + // z = wx + wy + + // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups()) + + // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + // self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithAggregationInMultipleDevices() + { + //def testColocateGradientsWithAggregationInMultipleDevices(self): + // with ops.Graph().as_default() as g: + // with g.device("/device:GPU:1"): + // w = constant(1.0, shape=[1, 1]) + // x = constant(1.0, shape=[1, 2]) + // y = constant(1.0, shape=[1, 2]) + // with g.device("/task:1"): + // wx = math_ops.matmul(w, x) + // with g.device("/task:2"): + // wy = math_ops.matmul(w, y) + // with g.device("/device:GPU:0"): + // z = wx + wy + + // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + // self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups()) + + // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + // self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups()) + } + + + [Ignore("TODO")] + [TestMethod] + public void testColocateGradientsWithGateGradients() + { + + //def testColocateGradientsWithGateGradients(self): + // if not test_util.is_gpu_available(): + // self.skipTest("No GPU available") + // with ops.Graph().as_default() as g: + // with g.device("/device:CPU:0"): + // x = constant(1.0, shape=[1, 1]) + // y = 
constant(1.0, shape=[1, 1]) + // s = x + y + // with g.device("/device:GPU:0"): + // z = math_ops.reduce_sum(s) + + // gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True, + // gate_gradients=True)[0] + // with session.Session(): + // # Make sure the placer doesn't complain. + // self.evaluate(gz_x) + + } + + [Ignore("TODO")] + [TestMethod] + public void testBoundaryStop() + { + //def testBoundaryStop(self): + // # Test that we don't differentiate 'x'. The gradient function for 'x' is + // # set explicitly to None so we will get an exception if the gradient code + // # tries to differentiate 'x'. + // with ops.Graph().as_default(): + // c = constant(1.0) + // x = array_ops.identity(c) + // y = x + 1.0 + // z = y + 1 + // grads = gradients.gradients(z, [x]) + // self.assertTrue(all(x is not None for x in grads)) + + } + + [Ignore("TODO")] + [TestMethod] + public void testBoundaryContinue() + { + //@test_util.run_v1_only("b/120545219") + //def testBoundaryContinue(self): + // # Test that we differentiate both 'x' and 'y' correctly when x is a + // # predecessor of y. + // with self.cached_session(): + // x = constant(1.0) + // y = x * 2.0 + // z = y * 3.0 + // grads = gradients.gradients(z, [x, y]) + // self.assertTrue(all(x is not None for x in grads)) + // self.assertEqual(6.0, grads[0].eval()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testAggregationMethodAccumulateN() + { + + //@test_util.run_v1_only("b/120545219") + //def testAggregationMethodAccumulateN(self): + // with self.cached_session(): + // x = constant(1.0) + // y = x * 2.0 + // z = y + y + y + y + y + y + y + y + y + y + // grads = gradients.gradients( + // z, [x, y], + // aggregation_method=gradients.AggregationMethod. + // EXPERIMENTAL_ACCUMULATE_N) + // self.assertTrue(all(x is not None for x in grads)) + // self.assertEqual(20.0, grads[0].eval()) + // self.assertEqual(10.0, grads[1].eval()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testAggregationMethodAddN() + { + //@test_util.run_v1_only("b/120545219") + //def testAggregationMethodAddN(self): + // with self.cached_session(): + // x = constant(1.0) + // y = x * 2.0 + // z = y + y + y + y + y + y + y + y + y + y + // grads = gradients.gradients( + // z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N) + // self.assertTrue(all(x is not None for x in grads)) + // self.assertEqual(20.0, grads[0].eval()) + // self.assertEqual(10.0, grads[1].eval()) + + + } + + [Ignore("TODO")] + [TestMethod] + public void testAggregationMethodTree() + { + //@test_util.run_v1_only("b/120545219") + //def testAggregationMethodTree(self): + // with self.cached_session(): + // x = constant(1.0) + // y = x * 2.0 + // z = y + y + y + y + y + y + y + y + y + y + // grads = gradients.gradients( + // z, [x, y], + // aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE) + // self.assertTrue(all(x is not None for x in grads)) + // self.assertEqual(20.0, grads[0].eval()) + // self.assertEqual(10.0, grads[1].eval()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testNoGradientForStringOutputs() + { + + //def testNoGradientForStringOutputs(self): + // with ops.Graph().as_default(): + + // def _TestOpGrad(_, float_grad, string_grad): + // """Gradient function for TestStringOutput.""" + // self.assertEquals(float_grad.dtype, dtypes.float32) + // self.assertFalse(string_grad) + // return float_grad + + // ops.RegisterGradient("TestStringOutput")(_TestOpGrad) + + // c = constant(1.0) + // x, _ = test_ops.test_string_output(c) + // z = 
x * 2.0 + // w = z * 3.0 + // grads = gradients.gradients(z, [c]) + // self.assertTrue(isinstance(grads[0], ops.Tensor)) + // grads = gradients.gradients(w, [c]) + // self.assertTrue(isinstance(grads[0], ops.Tensor)) + } + + [Ignore("TODO")] + [TestMethod] + public void testSingletonIndexedSlices() + { + + //def testSingletonIndexedSlices(self): + // with ops.Graph().as_default(): + // x = array_ops.placeholder(dtypes.float32) + // y = array_ops.identity(x) + // dy = ops.IndexedSlices( + // array_ops.placeholder(dtypes.float32), + // array_ops.placeholder(dtypes.int32)) + // dx, = gradients.gradients(y, x, grad_ys=dy) + // # The IndexedSlices gradient of tf.identity is the identity map. + // with self.cached_session() as sess: + // vdx, vdy = sess.run( + // [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]}) + // self.assertEqual(vdx, vdy) + } + + [Ignore("TODO")] + [TestMethod] + public void testNonDifferentiableSwitchInWhileLoop() + { + + + //@test_util.run_v1_only("b/120545219") + //def testNonDifferentiableSwitchInWhileLoop(self): + // with ops.Graph().as_default(): + // v = array_ops.placeholder(dtypes.float32, []) + + // def _Step(i, a, ta): + // a += math_ops.cast(v, dtypes.int32) + // return (i + 1, a, ta.write(i, a)) + + // n = 4 + // i, _, ta = control_flow_ops.while_loop( + // lambda i, *_: i < n, + // _Step, [0, 0, tensor_array_ops.TensorArray( + // dtypes.int32, size=n)]) + // target = ta.read(i - 1) + // grad, = gradients.gradients(target, v) + // self.assertIsNone(grad) + + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableReadValueGradient() + { + + //def testVariableReadValueGradient(self): + // with ops.Graph().as_default(): + // init = constant_op.constant(100.0) + // var = variables.Variable(init) + // gradient = gradients.gradients(var.read_value(), var) + // self.assertIsNotNone(gradient) + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableAsGraphElementGradient() + { + //def testVariableAsGraphElementGradient(self): + // with ops.Graph().as_default() as graph: + // init = constant_op.constant(100.0) + // var = variables.Variable(init) + // gradient = gradients.gradients(graph.as_graph_element(var), var) + // self.assertIsNotNone(gradient) + } + + [Ignore("TODO")] + [TestMethod] + public void testVariableRefGradient() + { + + //@test_util.run_v1_only("b/120545219") + //def testVariableRefGradient(self): + // with ops.Graph().as_default(): + // init = constant_op.constant(100.0) + // var = variables.VariableV1(init) + // gradient = gradients.gradients(var._ref(), var) + // self.assertIsNotNone(gradient) + } + + [Ignore("TODO")] + [TestMethod] + public void testDependentYs() + { + //@test_util.run_v1_only("b/120545219") + //def testDependentYs(self): + // with self.cached_session(): + // x = constant_op.constant(3.0) + // y = math_ops.square(x) + // y1 = math_ops.square(y) + // y2 = math_ops.square(y1) + // g = gradients.gradients([y, y2], x) + // self.assertAllClose(17502.0, g[0].eval()) + // g = gradients.gradients(y + y2, x) + // self.assertAllClose(17502.0, g[0].eval()) + // z = array_ops.identity(y) + // z2 = array_ops.identity(y2) + // g = gradients.gradients([z, z2], x) + // self.assertAllClose(17502.0, g[0].eval()) + + } + + [Ignore("TODO")] + [TestMethod] + public void testPartialDerivatives() + { + + //@test_util.run_v1_only("b/120545219") + //def testPartialDerivatives(self): + // with self.cached_session(): + // x = constant_op.constant(1.) 
+ // y = 2 * x + // z = x + y + // totalg = gradients.gradients(z, [x, y]) + // self.assertEqual([3.0, 1.0], [g.eval() for g in totalg]) + // partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y]) + // self.assertEqual([1.0, 1.0], [g.eval() for g in partialg]) + } + + [Ignore("TODO")] + [TestMethod] + public void testStopGradients() + { + + + //@test_util.run_v1_only("b/120545219") + //def testStopGradients(self): + // def _MakeGraph(rng, stop_gradients=()): + // def _FunctionOf(xs, k=3): + // return ops.convert_to_tensor( + // sum(math_ops.matmul(rng.rand(k, k), x) for x in xs) + // + rng.rand(k, k)) + + // a = _FunctionOf([]) + // if "a" in stop_gradients: a = array_ops.stop_gradient(a) + // b = _FunctionOf([a]) + // if "b" in stop_gradients: b = array_ops.stop_gradient(b) + // c = _FunctionOf([a, b]) + // if "c" in stop_gradients: c = array_ops.stop_gradient(c) + // d = _FunctionOf([b, c]) + // if "d" in stop_gradients: d = array_ops.stop_gradient(d) + // return dict(a=a, b=b, c=c, d=d) + + // def _Gradients(ys, xs, **kwargs): + // dydxs = gradients.gradients(ys, xs, **kwargs) + // dydxs = [0. * x if dydx is None else dydx + // for x, dydx in zip(xs, dydxs)] + // return dydxs + // seed = np.random.randint(1000) + // cases = [] + // subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split() + // graph = _MakeGraph(np.random.RandomState(seed)) + // for constants in subsets: + // graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants) + // for variables_ in subsets: + // # compute the gradient when stopped using tf.stop_gradients + // grad1 = _Gradients([graph_with_stops["d"]], + // [graph_with_stops[v] for v in variables_]) + // # compute the gradient when stopped using the stop_gradients kwarg + // grad2 = _Gradients([graph["d"]], + // [graph[v] for v in variables_], + // stop_gradients=[graph[v] for v in constants]) + // cases.append(dict(grad1=grad1, grad2=grad2, + // constants=constants, variables=variables_)) + + // # evaluate all tensors in one call to session.run for speed + // with self.cached_session() as sess: + // results = sess.run([(case["grad1"], case["grad2"]) for case in cases]) + + // for (npgrad1, npgrad2), case in zip(results, cases): + // for a, b in zip(npgrad1, npgrad2): + // np.testing.assert_allclose(a, b) + + } + + [Ignore("TODO")] + [TestMethod] + public void testUnconnectedGradientsNoneUnconnectedGradients() + { + + + //def testUnconnectedGradientsNoneUnconnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0, shape=[2, 2]) + // y = constant(3.0, shape=[3, 1]) + // grad = gradients.gradients( + // [y], [x], unconnected_gradients="none") + // self.assertIsNone(grad[0]) + } + + [Ignore("TODO")] + [TestMethod] + public void testUnconnectedGradientsZerosUnconnectedGradients() + { + + + //def testUnconnectedGradientsZerosUnconnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0, shape=[2, 2]) + // y = constant(3.0, shape=[3, 1]) + // grads = gradients.gradients( + // [y], [x], unconnected_gradients="zero") + // with self.cached_session() as sess: + // self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0]) + } + + [Ignore("TODO")] + [TestMethod] + public void testUnconnectedGradientsZeroConnectedGradients() + { + + + + //def testUnconnectedGradientsZeroConnectedGradients(self): + // with ops.Graph().as_default(): + // x = constant(1.0) + // y = x * 3.0 + // grad = gradients.gradients( + // [y], [x], unconnected_gradients="zero") + // with self.cached_session() 
as sess: + // self.assertEquals(3.0, self.evaluate(grad)[0]) + } + + [Ignore("TODO")] + [TestMethod] + public void testUnknownUnconnectedGradientsValueGiven() + { + //def testUnknownUnconnectedGradientsValueGiven(self): + // with ops.Graph().as_default(): + // x = constant(1.0) + // y = constant(1.0) + // with self.assertRaisesRegexp( + // ValueError, "Unknown value for unconnected_gradients: 'nonsense'"): + // gradients.gradients([y], [x], unconnected_gradients="nonsense") + + } + + + + /* + + + + */ + } +} diff --git a/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py b/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py new file mode 100644 index 00000000..c53afef6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py @@ -0,0 +1,1104 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.gradients.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import warnings + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function as framework_function +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_ops +from tensorflow.python.framework import test_util +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_grad # pylint: disable=unused-import +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient +from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import +from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import +from tensorflow.python.ops import functional_ops # pylint: disable=unused-import +from tensorflow.python.ops import gradients +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_grad # pylint: disable=unused-import +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variable_scope +from 
tensorflow.python.ops import variables +from tensorflow.python.ops.nn_ops import bias_add +from tensorflow.python.platform import googletest + + +class GradientsTest(test_util.TensorFlowTestCase): + + def testGradients(self): + with ops.Graph().as_default(): + inp = constant(1.0, shape=[32, 100], name="in") + w = constant(1.0, shape=[100, 10], name="w") + b = constant(1.0, shape=[10], name="b") + xw = math_ops.matmul(inp, w, name="xw") + h = bias_add(xw, b, name="h") + w_grad = gradients.gradients(h, w)[0] + self.assertEquals("MatMul", w_grad.op.type) + self.assertEquals(w_grad.op._original_op, xw.op) + self.assertTrue(w_grad.op.get_attr("transpose_a")) + self.assertFalse(w_grad.op.get_attr("transpose_b")) + + def testUnusedOutput(self): + with ops.Graph().as_default(): + w = constant(1.0, shape=[2, 2]) + x = constant(1.0, shape=[2, 2]) + wx = math_ops.matmul(w, x) + split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0) + c = math_ops.reduce_sum(split_wx[1]) + gw = gradients.gradients(c, [w])[0] + self.assertEquals("MatMul", gw.op.type) + + def testColocateGradients(self): + with ops.Graph().as_default() as g: + w = constant(1.0, shape=[1, 1]) + x = constant(1.0, shape=[1, 2]) + with g.device("/device:GPU:0"): + wx = math_ops.matmul(w, x) + gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0] + self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups()) + + def testColocateGradientsWithAggregation(self): + with ops.Graph().as_default() as g: + with g.device("/device:GPU:1"): + w = constant(1.0, shape=[1, 1]) + x = constant(1.0, shape=[1, 2]) + y = constant(1.0, shape=[1, 2]) + wx = math_ops.matmul(w, x) + wy = math_ops.matmul(w, y) + with g.device("/device:GPU:0"): + z = wx + wy + + gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups()) + + gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups()) + + def testColocateGradientsWithAggregationInMultipleDevices(self): + with ops.Graph().as_default() as g: + with g.device("/device:GPU:1"): + w = constant(1.0, shape=[1, 1]) + x = constant(1.0, shape=[1, 2]) + y = constant(1.0, shape=[1, 2]) + with g.device("/task:1"): + wx = math_ops.matmul(w, x) + with g.device("/task:2"): + wy = math_ops.matmul(w, y) + with g.device("/device:GPU:0"): + z = wx + wy + + gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] + self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups()) + + gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] + self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups()) + + def testColocateGradientsWithGateGradients(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + with ops.Graph().as_default() as g: + with g.device("/device:CPU:0"): + x = constant(1.0, shape=[1, 1]) + y = constant(1.0, shape=[1, 1]) + s = x + y + with g.device("/device:GPU:0"): + z = math_ops.reduce_sum(s) + + gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True, + gate_gradients=True)[0] + with session.Session(): + # Make sure the placer doesn't complain. + self.evaluate(gz_x) + + def testBoundaryStop(self): + # Test that we don't differentiate 'x'. The gradient function for 'x' is + # set explicitly to None so we will get an exception if the gradient code + # tries to differentiate 'x'. 
+ with ops.Graph().as_default(): + c = constant(1.0) + x = array_ops.identity(c) + y = x + 1.0 + z = y + 1 + grads = gradients.gradients(z, [x]) + self.assertTrue(all(x is not None for x in grads)) + + @test_util.run_v1_only("b/120545219") + def testBoundaryContinue(self): + # Test that we differentiate both 'x' and 'y' correctly when x is a + # predecessor of y. + with self.cached_session(): + x = constant(1.0) + y = x * 2.0 + z = y * 3.0 + grads = gradients.gradients(z, [x, y]) + self.assertTrue(all(x is not None for x in grads)) + self.assertEqual(6.0, grads[0].eval()) + + @test_util.run_v1_only("b/120545219") + def testAggregationMethodAccumulateN(self): + with self.cached_session(): + x = constant(1.0) + y = x * 2.0 + z = y + y + y + y + y + y + y + y + y + y + grads = gradients.gradients( + z, [x, y], + aggregation_method=gradients.AggregationMethod. + EXPERIMENTAL_ACCUMULATE_N) + self.assertTrue(all(x is not None for x in grads)) + self.assertEqual(20.0, grads[0].eval()) + self.assertEqual(10.0, grads[1].eval()) + + @test_util.run_v1_only("b/120545219") + def testAggregationMethodAddN(self): + with self.cached_session(): + x = constant(1.0) + y = x * 2.0 + z = y + y + y + y + y + y + y + y + y + y + grads = gradients.gradients( + z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N) + self.assertTrue(all(x is not None for x in grads)) + self.assertEqual(20.0, grads[0].eval()) + self.assertEqual(10.0, grads[1].eval()) + + @test_util.run_v1_only("b/120545219") + def testAggregationMethodTree(self): + with self.cached_session(): + x = constant(1.0) + y = x * 2.0 + z = y + y + y + y + y + y + y + y + y + y + grads = gradients.gradients( + z, [x, y], + aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE) + self.assertTrue(all(x is not None for x in grads)) + self.assertEqual(20.0, grads[0].eval()) + self.assertEqual(10.0, grads[1].eval()) + + def testNoGradientForStringOutputs(self): + with ops.Graph().as_default(): + + def _TestOpGrad(_, float_grad, string_grad): + """Gradient function for TestStringOutput.""" + self.assertEquals(float_grad.dtype, dtypes.float32) + self.assertFalse(string_grad) + return float_grad + + ops.RegisterGradient("TestStringOutput")(_TestOpGrad) + + c = constant(1.0) + x, _ = test_ops.test_string_output(c) + z = x * 2.0 + w = z * 3.0 + grads = gradients.gradients(z, [c]) + self.assertTrue(isinstance(grads[0], ops.Tensor)) + grads = gradients.gradients(w, [c]) + self.assertTrue(isinstance(grads[0], ops.Tensor)) + + def testSingletonIndexedSlices(self): + with ops.Graph().as_default(): + x = array_ops.placeholder(dtypes.float32) + y = array_ops.identity(x) + dy = ops.IndexedSlices( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32)) + dx, = gradients.gradients(y, x, grad_ys=dy) + # The IndexedSlices gradient of tf.identity is the identity map. 
+ with self.cached_session() as sess: + vdx, vdy = sess.run( + [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]}) + self.assertEqual(vdx, vdy) + + @test_util.run_v1_only("b/120545219") + def testNonDifferentiableSwitchInWhileLoop(self): + with ops.Graph().as_default(): + v = array_ops.placeholder(dtypes.float32, []) + + def _Step(i, a, ta): + a += math_ops.cast(v, dtypes.int32) + return (i + 1, a, ta.write(i, a)) + + n = 4 + i, _, ta = control_flow_ops.while_loop( + lambda i, *_: i < n, + _Step, [0, 0, tensor_array_ops.TensorArray( + dtypes.int32, size=n)]) + target = ta.read(i - 1) + grad, = gradients.gradients(target, v) + self.assertIsNone(grad) + + def testVariableReadValueGradient(self): + with ops.Graph().as_default(): + init = constant_op.constant(100.0) + var = variables.Variable(init) + gradient = gradients.gradients(var.read_value(), var) + self.assertIsNotNone(gradient) + + def testVariableAsGraphElementGradient(self): + with ops.Graph().as_default() as graph: + init = constant_op.constant(100.0) + var = variables.Variable(init) + gradient = gradients.gradients(graph.as_graph_element(var), var) + self.assertIsNotNone(gradient) + + @test_util.run_v1_only("b/120545219") + def testVariableRefGradient(self): + with ops.Graph().as_default(): + init = constant_op.constant(100.0) + var = variables.VariableV1(init) + gradient = gradients.gradients(var._ref(), var) + self.assertIsNotNone(gradient) + + @test_util.run_v1_only("b/120545219") + def testDependentYs(self): + with self.cached_session(): + x = constant_op.constant(3.0) + y = math_ops.square(x) + y1 = math_ops.square(y) + y2 = math_ops.square(y1) + g = gradients.gradients([y, y2], x) + self.assertAllClose(17502.0, g[0].eval()) + g = gradients.gradients(y + y2, x) + self.assertAllClose(17502.0, g[0].eval()) + z = array_ops.identity(y) + z2 = array_ops.identity(y2) + g = gradients.gradients([z, z2], x) + self.assertAllClose(17502.0, g[0].eval()) + + @test_util.run_v1_only("b/120545219") + def testPartialDerivatives(self): + with self.cached_session(): + x = constant_op.constant(1.) + y = 2 * x + z = x + y + totalg = gradients.gradients(z, [x, y]) + self.assertEqual([3.0, 1.0], [g.eval() for g in totalg]) + partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y]) + self.assertEqual([1.0, 1.0], [g.eval() for g in partialg]) + + @test_util.run_v1_only("b/120545219") + def testStopGradients(self): + def _MakeGraph(rng, stop_gradients=()): + def _FunctionOf(xs, k=3): + return ops.convert_to_tensor( + sum(math_ops.matmul(rng.rand(k, k), x) for x in xs) + + rng.rand(k, k)) + + a = _FunctionOf([]) + if "a" in stop_gradients: a = array_ops.stop_gradient(a) + b = _FunctionOf([a]) + if "b" in stop_gradients: b = array_ops.stop_gradient(b) + c = _FunctionOf([a, b]) + if "c" in stop_gradients: c = array_ops.stop_gradient(c) + d = _FunctionOf([b, c]) + if "d" in stop_gradients: d = array_ops.stop_gradient(d) + return dict(a=a, b=b, c=c, d=d) + + def _Gradients(ys, xs, **kwargs): + dydxs = gradients.gradients(ys, xs, **kwargs) + dydxs = [0. 
* x if dydx is None else dydx + for x, dydx in zip(xs, dydxs)] + return dydxs + + seed = np.random.randint(1000) + cases = [] + subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split() + graph = _MakeGraph(np.random.RandomState(seed)) + for constants in subsets: + graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants) + for variables_ in subsets: + # compute the gradient when stopped using tf.stop_gradients + grad1 = _Gradients([graph_with_stops["d"]], + [graph_with_stops[v] for v in variables_]) + # compute the gradient when stopped using the stop_gradients kwarg + grad2 = _Gradients([graph["d"]], + [graph[v] for v in variables_], + stop_gradients=[graph[v] for v in constants]) + cases.append(dict(grad1=grad1, grad2=grad2, + constants=constants, variables=variables_)) + + # evaluate all tensors in one call to session.run for speed + with self.cached_session() as sess: + results = sess.run([(case["grad1"], case["grad2"]) for case in cases]) + + for (npgrad1, npgrad2), case in zip(results, cases): + for a, b in zip(npgrad1, npgrad2): + np.testing.assert_allclose(a, b) + + def testUnconnectedGradientsNoneUnconnectedGradients(self): + with ops.Graph().as_default(): + x = constant(1.0, shape=[2, 2]) + y = constant(3.0, shape=[3, 1]) + grad = gradients.gradients( + [y], [x], unconnected_gradients="none") + self.assertIsNone(grad[0]) + + def testUnconnectedGradientsZerosUnconnectedGradients(self): + with ops.Graph().as_default(): + x = constant(1.0, shape=[2, 2]) + y = constant(3.0, shape=[3, 1]) + grads = gradients.gradients( + [y], [x], unconnected_gradients="zero") + with self.cached_session() as sess: + self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0]) + + def testUnconnectedGradientsZeroConnectedGradients(self): + with ops.Graph().as_default(): + x = constant(1.0) + y = x * 3.0 + grad = gradients.gradients( + [y], [x], unconnected_gradients="zero") + with self.cached_session() as sess: + self.assertEquals(3.0, self.evaluate(grad)[0]) + + def testUnknownUnconnectedGradientsValueGiven(self): + with ops.Graph().as_default(): + x = constant(1.0) + y = constant(1.0) + with self.assertRaisesRegexp( + ValueError, "Unknown value for unconnected_gradients: 'nonsense'"): + gradients.gradients([y], [x], unconnected_gradients="nonsense") + + +class FunctionGradientsTest(test_util.TensorFlowTestCase): + + @classmethod + def XSquarePlusB(cls, x, b): + return x * x + b + + @classmethod + def XSquarePlusBGradient(cls, x, b, g): + # Perturb gradients (multiply by 2), so we can test that this was called. + g *= 2.0 + return g * 2.0 * x, g + + @classmethod + def _PythonGradient(cls, op, grad): + # Perturb gradients (multiply by 3), so we can test that this was called. + grad *= 3.0 + return grad * op.inputs[0] * 2.0, grad + + @classmethod + def _GetFunc(cls, **kwargs): + return framework_function.Defun(dtypes.float32, dtypes.float32, ** + kwargs)(cls.XSquarePlusB) + + def _GetFuncGradients(self, f, x_value, b_value): + x = constant_op.constant(x_value, name="x") + b = constant_op.constant(b_value, name="b") + + y = f(x, b) + grads = gradients.gradients(y, [x, b]) + with self.cached_session() as sess: + return sess.run(grads) + + def testFunctionGradientsBasic(self): + g = ops.Graph() + with g.as_default(): + f = self._GetFunc() + # Get gradients (should add SymbolicGradient node for function). 
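+      # XSquarePlusB(x, b) = x*x + b, so at x = 2.0: dy/dx = 2*x = 4.0 and dy/db = 1.0,
+      # which are the values asserted below.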
+ grads = self._GetFuncGradients(f, [2.0], [1.0]) + self.assertAllEqual([4.0], grads[0]) + self.assertAllEqual([1.0], grads[1]) + + def testFunctionGradientsComposition(self): + with ops.Graph().as_default(): + f = self._GetFunc() + x = constant_op.constant([2.0], name="x") + b1 = constant_op.constant([1.0], name="b1") + b2 = constant_op.constant([1.0], name="b2") + + y = f(f(x, b1), b2) + # Build gradient graph (should add SymbolicGradient node for function). + grads = gradients.gradients(y, [x, b1]) + + with self.cached_session() as sess: + self.assertAllEqual([40.0], self.evaluate(grads)[0]) + self.assertAllEqual([10.0], self.evaluate(grads)[1]) + + def testFunctionGradientsWithGradFunc(self): + g = ops.Graph() + with g.as_default(): + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) + f = self._GetFunc(grad_func=grad_func) + # Get gradients (should add SymbolicGradient node for function, which + # uses the grad_func above, which multiplies all gradients by 2). + grads = self._GetFuncGradients(f, [2.0], [1.0]) + self.assertAllEqual([4.0 * 2], grads[0]) + self.assertAllEqual([1.0 * 2], grads[1]) + + def testFunctionGradientWithRegistration(self): + g = ops.Graph() + with g.as_default(): + f = self._GetFunc(python_grad_func=self._PythonGradient) + # Get gradients, using the python gradient function. It multiplies the + # gradients by 3. + grads = self._GetFuncGradients(f, [2.0], [1.0]) + self.assertAllEqual([4.0 * 3], grads[0]) + self.assertAllEqual([1.0 * 3], grads[1]) + + def testFunctionGradientWithGradFuncAndRegistration(self): + g = ops.Graph() + with g.as_default(): + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) + with self.assertRaisesRegexp(ValueError, "Gradient defined twice"): + f = self._GetFunc( + grad_func=grad_func, python_grad_func=self._PythonGradient) + f.add_to_graph(ops.Graph()) + + def testGradientWrtCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.defun() + def Foo(): + y = math_ops.multiply(x, 2.0, name="y") + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.cached_session() as sess: + self.assertEqual(self.evaluate(f), 2.0) + + def testGradientOfCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + y = math_ops.multiply(x, 2.0, name="y") + + @framework_function.Defun() + def Foo(): + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.cached_session() as sess: + self.assertEqual(self.evaluate(f), 2.0) + + def testCapturedResourceVariable(self): + with ops.Graph().as_default(): + var = resource_variable_ops.ResourceVariable(1.0, name="var") + + @function.defun() + def Foo(): + y = math_ops.multiply(var, 2.0, name="y") + g = gradients_impl.gradients(y, var) + return g[0] + + f = Foo() + with self.cached_session() as sess: + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(f), 2.0) + + def testCapturedNested(self): + with ops.Graph().as_default(): + x1 = constant_op.constant(1.0, name="x1") + x2 = constant_op.constant(2.0, name="x2") + x3 = math_ops.multiply(x1, x2, name="x3") + + @function.defun() + def Outer(): + outer1 = array_ops.identity(x1, name="outer1") + + @function.defun() + def Inner(): + inner1 = array_ops.identity(outer1, name="inner1") + inner2 = array_ops.identity(x2, name="inner2") + inner3 = array_ops.identity(x3, name="inner3") + 
return gradients_impl.gradients([inner1, inner2, inner3, x1], + [x1, x2]) + + return Inner() + + x1_grad, x2_grad = Outer() + with self.cached_session() as sess: + # 1.0 + None + 2.0 + 1.0 = 4.0 + self.assertEqual(self.evaluate(x1_grad), 4.0) + # None + 1.0 + 1.0 + None = 2.0 + self.assertEqual(self.evaluate(x2_grad), 2.0) + + def testCapturedFromFunction(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.defun() + def Outer(): + y = math_ops.multiply(x, 2.0, name="y") + + @function.defun() + def Inner(): + z = math_ops.multiply(y, 3.0, name="z") + g = gradients_impl.gradients(z, y) + return g[0] + + return Inner() + + z_grad = Outer() + with self.cached_session() as sess: + self.assertEqual(self.evaluate(z_grad), 3.0) + + def testCapturedEagerTensors(self): + # Test that we can handle captured eager tensors unrelated to the gradient + # computation (i.e. we need to ignore them). + # TODO(skyewm): make it an error if you try to take the gradient wrt a + # captured EagerTensor + with context.eager_mode(): + c = constant_op.constant(2.0, name="c") + + @function.defun + def Foo(): + x = constant_op.constant(10.0, name="x") + y = math_ops.multiply(x, c, name="y") + z = math_ops.multiply(y, 3.0, name="z") + g = gradients_impl.gradients(z, x) + return g[0] + + self.assertEqual(Foo().numpy(), 6.0) + + +class StopGradientTest(test_util.TensorFlowTestCase): + + def testStopGradient(self): + with ops.Graph().as_default(): + inp = constant(1.0, shape=[100, 32], name="in") + out = array_ops.stop_gradient(inp) + igrad = gradients.gradients(out, inp)[0] + assert igrad is None + + +class PreventGradientTest(test_util.TensorFlowTestCase): + + def testPreventGradient(self): + with ops.Graph().as_default(): + inp = constant(1.0, shape=[100, 32], name="in") + out = array_ops.prevent_gradient(inp) + with self.assertRaisesRegexp(LookupError, "explicitly disabled"): + _ = gradients.gradients(out, inp) + + +class HessianVectorProductTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testHessianVectorProduct(self): + # Manually compute the Hessian explicitly for a low-dimensional problem + # and check that HessianVectorProduct matches multiplication by the + # explicit Hessian. + # Specifically, the Hessian of f(x) = x^T A x is + # H = A + A^T. + # We expect HessianVectorProduct(f(x), x, v) to be H v. + m = 4 + rng = np.random.RandomState([1, 2, 3]) + mat_value = rng.randn(m, m).astype("float32") + v_value = rng.randn(m, 1).astype("float32") + x_value = rng.randn(m, 1).astype("float32") + hess_value = mat_value + mat_value.T + hess_v_value = np.dot(hess_value, v_value) + for use_gpu in [False, True]: + with self.cached_session(use_gpu=use_gpu): + mat = constant_op.constant(mat_value) + v = constant_op.constant(v_value) + x = constant_op.constant(x_value) + mat_x = math_ops.matmul(mat, x, name="Ax") + x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx") + hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0] + hess_v_actual = self.evaluate(hess_v) + self.assertAllClose(hess_v_value, hess_v_actual) + + +class HessianTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testHessian1D(self): + # Manually compute the Hessian explicitly for a low-dimensional problem + # and check that `hessian` matches. Specifically, the Hessian of + # f(x) = x^T A x is H = A + A^T. 
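+    # (Derivation: grad f(x) = (A + A^T) x, so the Hessian is the constant matrix
+    # A + A^T, i.e. hess_value = mat_value + mat_value.T below.)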
+ m = 4 + rng = np.random.RandomState([1, 2, 3]) + mat_value = rng.randn(m, m).astype("float32") + x_value = rng.randn(m).astype("float32") + hess_value = mat_value + mat_value.T + with self.session(use_gpu=True): + mat = constant_op.constant(mat_value) + x = constant_op.constant(x_value) + x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :]) + hess = gradients.hessians(x_mat_x, x)[0] + hess_actual = self.evaluate(hess) + self.assertAllClose(hess_value, hess_actual) + + @test_util.run_v1_only("b/120545219") + def testHessian1D_multi(self): + # Test the computation of the hessian with respect to multiple tensors + m = 4 + n = 3 + rng = np.random.RandomState([1, 2, 3]) + mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)] + x_values = [rng.randn(m).astype("float32") for _ in range(n)] + hess_values = [mat_value + mat_value.T for mat_value in mat_values] + with self.session(use_gpu=True): + mats = [constant_op.constant(mat_value) for mat_value in mat_values] + xs = [constant_op.constant(x_value) for x_value in x_values] + xs_mats_xs = [ + math_ops.reduce_sum(x[:, None] * mat * x[None, :]) + for x, mat in zip(xs, mats) + ] + hessians = gradients.hessians(xs_mats_xs, xs) + hessians_actual = [hess.eval() for hess in hessians] + for hess_value, hess_actual in zip(hess_values, hessians_actual): + self.assertAllClose(hess_value, hess_actual) + + @test_util.run_v1_only("b/120545219") + def testHessianInvalidDimension(self): + for shape in [(10, 10), None]: + with self.cached_session(use_gpu=True): + x = array_ops.placeholder(dtypes.float32, shape) + # Expect a ValueError because the dimensions are wrong + with self.assertRaises(ValueError): + gradients.hessians(x, x) + + @test_util.run_v1_only("b/120545219") + def testHessian2D_square_matrix(self): + # Manually compute the Hessian explicitly for a low-dimensional problem + # and check that `hessian` matches. 
Specifically, the Hessian of + # f(x) = 1/2 * x^T * x is H = constant (block identity matrix) + m = 3 + rng = np.random.RandomState([1, 2, 3]) + x_value = rng.randn(m, m).astype("float32") + with self.session(use_gpu=True): + x = constant_op.constant(x_value) + x_square = math_ops.reduce_sum( + math_ops.matmul(array_ops.transpose(x), x) * 0.5 + ) + hess = gradients.hessians(x_square, x)[0] + hess_actual = self.evaluate(hess) + hess_value = np.bmat([ + [elem*np.ones((m, m)) for elem in vec] + for vec in np.eye(m) + ]).astype("float32") + self.assertAllEqual((m, m, m, m), hess_actual.shape) + self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m))) + + @test_util.run_v1_only("b/120545219") + def testHessian2D_non_square_matrix(self): + m = 3 + n = 4 + rng = np.random.RandomState([1, 2, 3]) + x_value = rng.randn(m, n).astype("float32") + with self.session(use_gpu=True): + x = constant_op.constant(x_value) + x_square = math_ops.reduce_sum( + math_ops.matmul(array_ops.transpose(x), x) * 0.5 + ) + hess = gradients.hessians(x_square, x)[0] + hess_actual = self.evaluate(hess) + hess_value = np.bmat([ + [elem*np.ones((n, n)) for elem in vec] + for vec in np.eye(m) + ]).astype("float32") + self.assertAllEqual((m, n, m, n), hess_actual.shape) + self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n))) + + +class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testIndexedSlicesToTensor(self): + with self.cached_session(): + np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) + c = constant_op.constant(np_val) + c_sparse = math_ops._as_indexed_slices(c) + self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) + c_dense = math_ops.multiply(c_sparse, 1.0) + self.assertAllClose(np_val, self.evaluate(c_dense)) + + @test_util.run_v1_only("b/120545219") + def testIndexedSlicesToTensorList(self): + with self.cached_session(): + numpy_list = [] + dense_list = [] + sparse_list = [] + for _ in range(3): + np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) + c = constant_op.constant(np_val) + c_sparse = math_ops._as_indexed_slices(c) + numpy_list.append(np_val) + dense_list.append(c) + sparse_list.append(c_sparse) + packed_dense = array_ops.stack(dense_list) + packed_sparse = array_ops.stack(sparse_list) + self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse)) + + @test_util.run_v1_only("b/120545219") + def testInt64Indices(self): + with self.cached_session(): + np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) + c = constant_op.constant(np_val) + c_sparse = math_ops._as_indexed_slices(c) + c_sparse = ops.IndexedSlices( + c_sparse.values, + math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape) + self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) + c_dense = math_ops.multiply(c_sparse, 1.0) + self.assertAllClose(np_val, self.evaluate(c_dense)) + + @test_util.run_v1_only("b/120545219") + def testWarnings(self): + # TODO(gunan) Reenable after this issue is fixed: + # https://github.com/google/protobuf/issues/2812 + if sys.version_info >= (3, 5): + self.skipTest("Skipped test for Python 3.5+") + + # Smaller than the threshold: no warning. + c_sparse = ops.IndexedSlices( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) + with warnings.catch_warnings(record=True) as w: + math_ops.multiply(c_sparse, 1.0) + self.assertEqual(0, len(w)) + + # Greater than or equal to the threshold: warning. 
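+    # 100 * 100 * 100 * 100 = 100000000 elements, i.e. at or above the warning threshold,
+    # hence the "100000000 elements" message asserted below.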
+ c_sparse = ops.IndexedSlices( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100])) + # "always" filter prevents the warning from being suppressed if it was + # already triggered in a different test. + warnings.simplefilter("always") + with warnings.catch_warnings(record=True) as w: + math_ops.multiply(c_sparse, 1.0) + self.assertEqual(1, len(w)) + self.assertTrue( + "with 100000000 elements. This may consume a large amount of memory." in + str(w[0].message)) + + # Unknown dense shape: warning. + c_sparse = ops.IndexedSlices( + array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), + array_ops.placeholder(dtypes.int32)) + with warnings.catch_warnings(record=True) as w: + math_ops.multiply(c_sparse, 1.0) + self.assertEqual(1, len(w)) + self.assertTrue( + "of unknown shape. This may consume a large amount of memory." in + str(w[0].message)) + + +class OnlyRealGradientsTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testRealOnly(self): + x = constant_op.constant(7+3j, dtype=dtypes.complex64) + y = math_ops.square(x) + with self.assertRaisesRegexp( + TypeError, + r"Gradients of complex tensors must set grad_ys " + r"\(y\.dtype = tf\.complex64\)"): + gradients.gradients(y, x) + + +class ResourceCondTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testBasic(self): + gamma = resource_variable_ops.ResourceVariable( + np.random.random((3,)), + dtype="float32", name="gamma") + + inputs = array_ops.ones(shape=(3,), dtype="float32") + + def TestFn(): + output = inputs + gamma + return output + + training = array_ops.placeholder_with_default(True, shape=()) + output = control_flow_ops.cond( + training, TestFn, lambda: inputs) + + loss = output + + grads = gradients.gradients( + loss, [gamma]) + self.assertTrue(None not in grads) + + +class CustomGradientTest(test_util.TensorFlowTestCase): + + def testCustomGradientTrivial(self): + + @custom_gradient.custom_gradient + def MyIdentity(x): + + def Grad(dy): + return [3 * dy] + + return x, Grad + + with ops.Graph().as_default(): + x = constant(3.) + y = MyIdentity(MyIdentity(x)) + dy = gradients.gradients(y, x)[0] + with session.Session(): + self.assertEqual(9., self.evaluate(dy)) + + def testCustomGradient(self): + + @custom_gradient.custom_gradient + def MyMultiply(x1, x2): + result = x1 * x2 + + def Grad(dy): + # Switched the ordering here. + return [dy * x1, dy * x2] + + return result, Grad + + with ops.Graph().as_default(): + x1 = constant(3.) + x2 = constant(5.) 
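+      # With the switched ordering in Grad, gradients.gradients returns [x1, x2] = [3., 5.]
+      # instead of the true product-rule result [x2, x1] = [5., 3.], so the assertion below
+      # proves the custom Grad function was actually used.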
+ y = MyMultiply(x1, x2) + dy = gradients.gradients(y, [x1, x2]) + with session.Session() as sess: + self.assertAllEqual([3., 5.], self.evaluate(dy)) + + def testCustomGradientErrors(self): + + @custom_gradient.custom_gradient + def F(x): + + def Grad(_): + raise RuntimeError("x") + + return x, Grad + + with ops.Graph().as_default(): + x = constant(1.0) + y = F(x) + with self.assertRaises(RuntimeError): + gradients.gradients(y, x) + + def testCustomGradientWithVariables(self): + + @custom_gradient.custom_gradient + def F(x): + out = core_layers.dense(x, 3, use_bias=False) + + def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name + self.assertEqual(1, len(variables)) + grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad) + return grads[0], [array_ops.ones((4, 3))] + + return out, Grad + + with ops.Graph().as_default(): + x = array_ops.ones((2, 4)) + with variable_scope.variable_scope("f", use_resource=True) as vs: + y = F(x) + all_vars = vs.global_variables() + assert len(all_vars) == 1 + grads = gradients.gradients(y, [x, all_vars[0]]) + for g in grads: + self.assertTrue(g is not None) + with session.Session() as sess: + self.evaluate(variables.global_variables_initializer()) + dw = sess.run(math_ops.reduce_sum(grads[1])) + self.assertEqual(12., dw) + + def testCustomGradientWithVariablesEager(self): + with context.eager_mode(): + layer = core_layers.Dense(4, use_bias=False) + + @custom_gradient.custom_gradient + def F(x): + out = layer(x) + + def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name + del out_grad + self.assertEqual(1, len(variables)) + return (array_ops.ones((3, 2)), + [array_ops.ones((2, 4))]) + + return out, Grad + + x = array_ops.ones((3, 2)) + 2. + with backprop.GradientTape() as tape: + tape.watch(x) + y = F(x) + w, = layer.variables + dx, dw = tape.gradient(y, [x, w]) + self.assertEqual(6., math_ops.reduce_sum(dx).numpy()) + self.assertEqual(8., math_ops.reduce_sum(dw).numpy()) + + @test_util.run_v1_only("b/120545219") + def testCustomGradientErrorsWithNonResourceVariables(self): + + def F(x, use_resource=False): + with variable_scope.variable_scope("f", use_resource=use_resource): + out = core_layers.dense(x, 4, use_bias=False) + + def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name + del out_grad + self.assertEqual(1, len(variables)) + return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))]) + + return out, Grad + + @custom_gradient.custom_gradient + def FResource(x): + return F(x, use_resource=True) + + @custom_gradient.custom_gradient + def FNonResource(x): + return F(x, use_resource=False) + + x = array_ops.ones((3, 2)) + 2. + + # Wrapping scope has use_resource=True but inner scope sets to False. Fails. + with variable_scope.variable_scope("vs1", use_resource=True): + with self.assertRaisesWithPredicateMatch(TypeError, + "must be `ResourceVariable`s"): + FNonResource(x) + + # Wrapping scope has use_resource=False but inner scope sets to True. + # Passes. 
+ with variable_scope.variable_scope("vs2", use_resource=False): + FResource(x) + + def testWithNumpyInputs(self): + with context.eager_mode(): + + @custom_gradient.custom_gradient + def F(x): + out = x + + def Grad(_): + return (None, None) + + return out, Grad + + x = np.ones((3, 2), dtype=np.float32) + # Smoke test to ensure numpy inputs are accepted + F(x) + + @test_util.run_v1_only("b/120545219") + def testRVGradientsDynamicCond(self): + with self.cached_session(): + alpha = resource_variable_ops.ResourceVariable( + np.random.random((1,)), + dtype="float32") + + conditional = array_ops.placeholder_with_default(True, shape=()) + output = control_flow_ops.cond( + conditional, lambda: alpha * 2, lambda: alpha * 3) + + g, = gradients_impl.gradients(output, alpha) + self.evaluate(variables.global_variables_initializer()) + self.assertAllEqual(g.eval(), [2.0]) + self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0]) + + +class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase): + + def _assert_indexed_slices_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def testNoGradients(self): + self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([])) + + def testOneGradient(self): + t = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + result = gradients_impl._AggregateIndexedSlicesGradients([t]) + self._assert_indexed_slices_equal(t, result) + + def testMultipleGradients(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices(constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) + self._assert_indexed_slices_equal(total, result) + + def testMultipleGradientsWithNones(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices(constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]])) + t3 = None + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3]) + self._assert_indexed_slices_equal(total, result) + + def testMixedTensorAndIndexedSlices(self): + t0 = math_ops._as_indexed_slices(constant_op.constant( + [[1., 2.], [0, 0], [3., 4.]])) + t1 = constant_op.constant( + [[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant( + [[1., 2.], [5, 6], [10., 12.]]) + result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) + self._assert_indexed_slices_equal(total, result) + + +class TensorListGradientsTest(test_util.TensorFlowTestCase): + + def testDefaultGradYs(self): + with ops.Graph().as_default(): + tl = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + a = constant(1.0) + tl = list_ops.tensor_list_push_back(tl, a) + + grad_tl = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0)) + + grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0] + with self.cached_session() as sess: + self.assertEquals(self.evaluate(grad), 5.) 
+ + +if __name__ == "__main__": + googletest.main() diff --git a/test/TensorFlowNET.UnitTest/nn_test/nn_test.py b/test/TensorFlowNET.UnitTest/nn_test/nn_test.py new file mode 100644 index 00000000..82fab741 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/nn_test/nn_test.py @@ -0,0 +1,1243 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for miscellaneous functionality in tensorflow.ops.nn.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +from tensorflow.python.ops.nn_impl import _compute_sampled_logits +from tensorflow.python.platform import test as test_lib + + +class ZeroFractionTest(test_lib.TestCase): + + def _ZeroFraction(self, x): + assert x.shape + total_elements = np.prod(x.shape) + nonzeros = np.count_nonzero(x.flatten()) + return 1.0 - nonzeros / total_elements + + @test_util.run_deprecated_v1 + def testZeroFraction(self): + x_shape = [5, 17] + x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32) + y_np = self._ZeroFraction(x_np) + + x_tf = constant_op.constant(x_np) + x_tf.set_shape(x_shape) + y_tf = nn_impl.zero_fraction(x_tf) + y_tf_np = self.evaluate(y_tf) + + eps = 1e-8 + self.assertAllClose(y_tf_np, y_np, eps) + + @test_util.run_deprecated_v1 + def testZeroFractionEmpty(self): + x = np.zeros(0) + y = self.evaluate(nn_impl.zero_fraction(x)) + self.assertTrue(np.isnan(y)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Zeros(self): + sparsity = nn_impl.zero_fraction( + array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(1.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testZeroFraction2_27Ones(self): + sparsity = nn_impl.zero_fraction( + array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8)) + self.assertAllClose(0.0, self.evaluate(sparsity)) + + @test_util.run_deprecated_v1 + def testUnknownSize(self): + value = array_ops.placeholder(dtype=dtypes.float32) + sparsity = nn_impl.zero_fraction(value) + with self.cached_session() as sess: + self.assertAllClose( + 0.25, + sess.run(sparsity, {value: [[0., 1.], [0.3, 
2.]]})) + + +class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = np.exp(x - m) + z = u.sum(1)[:, np.newaxis] + return u / z + + @test_util.run_in_graph_and_eager_modes + def testSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1) + y_tf_np = self.evaluate(y_tf) + y_tf_last_dim_np = self.evaluate(y_tf_last_dim) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_last_dim_np, y_np, eps) + + def testSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.softmax_v2(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 2e-8 + self.assertLess(err, eps) + + +class LogPoissonLossTest(test_lib.TestCase): + + def _log_poisson_loss(self, x, z, compute_full_loss=False): + lpl = np.exp(x) - z * x + if compute_full_loss: + stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) + lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) 
+ return lpl + + @test_util.run_in_graph_and_eager_modes + def testLogPoissonLoss(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32) + y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False) + y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True) + y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True) + y_tf_np = self.evaluate(y_tf) + y_tf_np_stirling = self.evaluate(y_tf_stirling) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float64) + z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False) + y_tf_stirling = nn_impl.log_poisson_loss( + z_np, x_tf, compute_full_loss=True) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + err_stirling = gradient_checker.compute_gradient_error( + x_tf, x_shape, y_tf_stirling, x_shape) + eps = 1e-6 + self.assertLess(err, eps) + self.assertLess(err_stirling, eps) + + +class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): + + def _log_softmax(self, x): + assert len(x.shape) == 2 + m = x.max(1)[:, np.newaxis] + u = x - m + return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) + + @test_util.run_in_graph_and_eager_modes + def testLogSoftmax(self): + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + y_np = self._log_softmax(x_np) + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + y_tf_np = self.evaluate(y_tf) + eps = 1e-3 + self.assertAllClose(y_tf_np, y_np, eps) + + def testLogSoftmaxAxes(self): + arr = np.linspace(0., 1, 12).reshape(3, 4) + x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2) + y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0) + z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0) + x_neg_axis_tf = self.evaluate(x_neg_axis) + y_pos_axis_tf = self.evaluate(y_pos_axis) + z_gt_axis_tf = self.evaluate(z_gt_axis) + eps = 1e-3 + self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) + self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) + + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + @test_util.run_deprecated_v1 + def testGradient(self, x_shape): + x_np = np.random.randn(*x_shape).astype(np.float64) + with self.cached_session(): + x_tf = constant_op.constant(x_np) + y_tf = nn_ops.log_softmax_v2(x_tf) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + eps = 1e-7 + self.assertLess(err, eps) + + +class L2LossTest(test_lib.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testL2Loss(self): + for dtype in [dtypes.float32, dtypes.float64]: + x = constant_op.constant( + [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype) + l2loss = nn_ops.l2_loss(x) + value = self.evaluate(l2loss) + self.assertAllClose(7.0, value) + + @test_util.run_deprecated_v1 + def testGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) # Make it reproducible. 
+ x_val = np.random.random_sample(x_shape).astype(np.float64) + with self.cached_session(): + x = constant_op.constant(x_val, name="x") + output = nn_ops.l2_loss(x) + err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) + print("L2Loss gradient err = %g " % err) + err_tolerance = 1e-10 + self.assertLess(err, err_tolerance) + + +class L2NormalizeTest(test_lib.TestCase): + + def _l2Normalize(self, x, dim): + if isinstance(dim, list): + norm = np.linalg.norm(x, axis=tuple(dim)) + for d in dim: + norm = np.expand_dims(norm, d) + return x / norm + else: + norm = np.apply_along_axis(np.linalg.norm, dim, x) + return x / np.expand_dims(norm, dim) + + @test_util.run_in_graph_and_eager_modes + def testL2Normalize(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in range(len(x_shape)): + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_in_graph_and_eager_modes + def testL2NormalizeDimArray(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + y_np = self._l2Normalize(x_np, dim) + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + self.assertAllClose(y_np, self.evaluate(y_tf)) + + @test_util.run_deprecated_v1 + def testL2NormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float64) + for dim in range(len(x_shape)): + with self.cached_session(): + x_tf = constant_op.constant(x_np, name="x") + y_tf = nn_impl.l2_normalize_v2(x_tf, dim) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) + print("L2Normalize gradient err = %g " % err) + self.assertLess(err, 1e-4) + + +class DropoutTest(test_lib.TestCase): + + def testDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropout(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. This time with shaped + # noise. 
+ x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + def testShapedDropoutCorrelation(self): + # Runs a shaped dropout and tests that the correlations are correct. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + # Verifies that each y column as only one type of activation. + for i in xrange(x_dim): + sorted_value = np.unique(np.sort(value[i, :])) + self.assertEqual(sorted_value.size, 1) + + @test_util.run_deprecated_v1 + def testDropoutPlaceholderKeepProb(self): + # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate + # that it is producing approximately the right number of ones over a large + # number of samples, based on the keep probability. + x_dim = 40 + y_dim = 30 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.cached_session(): + t = constant_op.constant( + 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + keep_prob_placeholder = array_ops.placeholder(dtypes.float32) + dropout = nn_ops.dropout(t, keep_prob_placeholder) + final_count = 0 + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + for _ in xrange(0, num_iter): + value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob}) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testShapedDropoutUnknownShape(self): + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + dropout_x = nn_ops.dropout( + x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) + self.assertEqual(x.get_shape(), dropout_x.get_shape()) + + def testPartialShapedDropout(self): + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + # Set noise_shape=[None, 1] which means [x_dim, 1]. 
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = self.evaluate(dropout) + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + + @test_util.run_deprecated_v1 + def testInvalidKeepProb(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout(t, [0.0, 1.0]) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float64)) + with self.assertRaises(ValueError): + nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) + + @test_util.run_deprecated_v1 + def testInvalidRate(self): + x_dim = 40 + y_dim = 30 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, -1.0) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, 1.1) + with self.assertRaises(ValueError): + nn_ops.dropout_v2(t, [0.0, 1.0]) + + @test_util.run_deprecated_v1 + def testShapedDropoutShapeError(self): + # Runs shaped dropout and verifies an error is thrown on misshapen noise. + x_dim = 40 + y_dim = 30 + keep_prob = 0.5 + t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim + 3]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim]) + # test that broadcasting proceeds + _ = nn_ops.dropout(t, keep_prob, noise_shape=[y_dim]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, y_dim]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1]) + _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, 1]) + + def testNoDropoutFast(self): + x = array_ops.zeros((5,)) + y = nn_ops.dropout(x, keep_prob=1) + self.assertTrue(x is y) + + y = nn_ops.dropout_v2(x, rate=0) + self.assertTrue(x is y) + + def testDropoutWithIntegerInputs(self): + x = constant_op.constant([1, 1, 1, 1, 1]) + with self.assertRaises(ValueError): + _ = nn_ops.dropout(x, 0.5) + + +class ComputeSampledLogitsTest(test_lib.TestCase): + + def setUp(self): + self._eps = 1e-3 + + def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels, + sampled, subtract_log_q): + """Randomly generates input/output data for a single test case. + + This function returns numpy constants for use in a test case. + + Args: + num_classes: An int. The number of embedding classes in the test case. + dim: An int. The dimension of the embedding. + batch_size: An int. The batch size. + num_true: An int. The number of target classes per training example. + labels: A list of batch_size * num_true ints. The target classes. 
+ sampled: A list of indices in [0, num_classes). + subtract_log_q: A bool corresponding to the parameter in + _compute_sampled_logits(). + + Returns: + weights: Embedding weights to use as test input. It is a numpy array + of shape [num_classes, dim] + biases: Embedding biases to use as test input. It is a numpy array + of shape [num_classes]. + hidden_acts: Forward activations of the network to use as test input. + It is a numpy array of shape [batch_size, dim]. + sampled_vals: A tuple based on `sampled` to use as test input in the + format returned by a *_candidate_sampler function. + exp_logits: The output logits expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. + exp_labels: The output labels expected from _compute_sampled_logits(). + It is a numpy array of shape [batch_size, num_true + len(sampled)]. + """ + weights = np.random.randn(num_classes, dim).astype(np.float32) + biases = np.random.randn(num_classes).astype(np.float32) + hidden_acts = np.random.randn(batch_size, dim).astype(np.float32) + + true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32) + sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32) + sampled_vals = (sampled, true_exp, sampled_exp) + + sampled_w, sampled_b = weights[sampled], biases[sampled] + true_w, true_b = weights[labels], biases[labels] + + true_logits = np.sum( + hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape( + (batch_size, num_true, dim)), + axis=2) + true_b = true_b.reshape((batch_size, num_true)) + true_logits += true_b + sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b + + if subtract_log_q: + true_logits -= np.log(true_exp) + sampled_logits -= np.log(sampled_exp[np.newaxis, :]) + + exp_logits = np.concatenate([true_logits, sampled_logits], axis=1) + exp_labels = np.hstack((np.ones_like(true_logits) / num_true, + np.zeros_like(sampled_logits))) + + return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels + + def _ShardTestEmbeddings(self, weights, biases, num_shards): + """Shards the weights and biases returned by _GenerateTestData. + + Args: + weights: The weights returned by _GenerateTestData. + biases: The biases returned by _GenerateTestData. + num_shards: The number of shards to create. + + Returns: + sharded_weights: A list of size `num_shards` containing all the weights. + sharded_biases: A list of size `num_shards` containing all the biases. 
+ """ + with ops.Graph().as_default() as g: + sharded_weights = variable_scope.get_variable( + "w", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(weights)) + sharded_biases = variable_scope.get_variable( + "b", + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=constant_op.constant(biases)) + with self.session(graph=g) as sess: + variables.global_variables_initializer().run() + return self.evaluate([list(sharded_weights), list(sharded_biases)]) + + def testShapes(self): + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertEqual(exp_logits.shape, got_logits.shape, self._eps) + self.assertEqual(exp_labels.shape, got_labels.shape, self._eps) + + def testBasic(self): + """Without accidental hit removal or subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_basic_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testAccidentalHitRemoval(self): + """With accidental hit removal, no subtract_log_q.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + sampled = [1, 0, 2, 3] + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, _, + _) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=sampled, + subtract_log_q=False) + logits_tensor, _ = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + 
labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=len(sampled), + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=True, + partition_strategy="div", + name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true) + # Test that the exponentiated logits of accidental hits are near 0. + # First we need to find the hits in this random test run: + labels_reshape = labels.reshape((batch_size, num_true)) + got_logits = self.evaluate(logits_tensor) + for row in xrange(batch_size): + row_labels = labels_reshape[row, :] + for col in xrange(len(sampled)): + if sampled[col] in row_labels: + # We need to add the num_true_test offset into logits_* + self.assertNear( + np.exp(got_logits[row, col + num_true]), 0., self._eps) + + def testSubtractLogQ(self): + """With subtract_log_q, no accidental hit removal.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=True, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_subtract_log_q_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testSharded(self): + """With sharded weights and sharded biases.""" + np.random.seed(0) + num_classes = 5 + batch_size = 3 + + for num_true in range(1, 5): + labels = np.random.randint( + low=0, high=num_classes, size=batch_size * num_true) + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=num_true, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=False) + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + logits_tensor, labels_tensor = _compute_sampled_logits( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant( + labels, dtype=dtypes.int64, shape=(batch_size, num_true)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=num_true, + sampled_values=sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + partition_strategy="div", + name="sampled_logits_sharded_num_true_%d" % num_true) + got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor]) + self.assertAllClose(exp_logits, got_logits, self._eps) + self.assertAllClose(exp_labels, got_labels, self._eps) + + def testNCELoss(self): + # A simple test to verify the numerics. 
+ + def _SigmoidCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + pred = 1. / (1. + np.exp(-logits)) + eps = 0.0001 + pred = np.minimum(np.maximum(pred, eps), 1 - eps) + return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_nce_loss = np.sum( + _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1) + + got_nce_loss = nn_impl.nce_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + # Test with sharded weights and sharded biases. + weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_nce_loss = nn_impl.nce_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals) + + self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4) + + def testSampledSoftmaxLoss(self): + # A simple test to verify the numerics. + + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + (weights, biases, hidden_acts, sampled_vals, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=[1, 0, 2, 3], + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights), + biases=constant_op.constant(biases), + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + # Test with sharded weights and sharded biases. 
+ weight_shards, bias_shards = self._ShardTestEmbeddings( + weights, biases, num_shards=3) + got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2( + weights=[constant_op.constant(shard) for shard in weight_shards], + biases=[constant_op.constant(shard) for shard in bias_shards], + labels=constant_op.constant(labels, shape=(batch_size, 1)), + inputs=constant_op.constant(hidden_acts), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-4) + + def testSampledSoftmaxLossBf16(self): + # A simple test to verify the numerics for bfloat16. + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + sampled = [1, 0, 2, 3] + (weights, biases, hidden_acts, _, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=sampled, + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + true_exp_bf16 = np.full([batch_size, 1], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_exp_bf16 = np.full([len(sampled)], + fill_value=0.5, + dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + + got_sampled_softmax_loss = math_ops.cast( + nn_impl.sampled_softmax_loss_v2( + weights=constant_op.constant(weights, dtype=dtypes.bfloat16), + biases=constant_op.constant(biases, dtype=dtypes.bfloat16), + labels=constant_op.constant( + labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), + inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals_bf16, + remove_accidental_hits=False), dtypes.float32) + + self.assertAllClose(exp_sampled_softmax_loss, + self.evaluate(got_sampled_softmax_loss), 1e-1) + + +class CReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1) + + z = self.evaluate(nn_ops.crelu(constant_op.constant(x))) + self.assertAllClose(y, z, 1e-4) + + +class ReluTest(test_lib.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + y = np.maximum(x, 0.0) + + z = self.evaluate(nn_ops.relu(constant_op.constant(x))) + self.assertAllEqual(y, z) + + @test_util.run_deprecated_v1 + def testNaNs(self): + # Test that relu(nan) = nan for various sizes. + for i in range(18): + x = np.zeros(i) + np.nan + with self.cached_session(): + z = nn_ops.relu(constant_op.constant(x)).eval() + self.assertTrue(np.isnan(z).all()) + + +class LeakyReluTest(test_lib.TestCase): + + def testRange(self): + batch_size = 3 + height, width = 4, 4 + np.random.seed(1) # Make it reproducible. 
+ inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype( + np.float32) + inputs = constant_op.constant(inputs) + + outputs = nn_ops.leaky_relu(inputs) + self.assertEquals(inputs.shape, outputs.shape) + + inputs, outputs = self.evaluate([inputs, outputs]) + + self.assertGreaterEqual(outputs.min(), 0.0) + self.assertLessEqual(outputs.max(), 1.0) + self.assertAllClose(inputs, outputs) + + @test_util.run_deprecated_v1 + def testValues(self): + for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]: + np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype) + outputs = nn_ops.leaky_relu(constant_op.constant(np_values)) + + outputs = self.evaluate(outputs) + + tol = 2e-3 if dtype == np.float16 else 1e-6 + self.assertAllClose( + outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol) + + @test_util.run_deprecated_v1 + def testName(self): + np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64) + outputs_with_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values), + name='test_relu_op') + self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0') + outputs_without_name_set = nn_ops.leaky_relu( + constant_op.constant(np_values)) + self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0') + + +class SwishTest(test_lib.TestCase): + + @test_util.run_deprecated_v1 + def testValues(self): + np_values = np.array( + [np.linspace(-10.0, 0.0, 100), + np.linspace(0.0, 10.0, 100)], + dtype=np.float32) + tf_values = constant_op.constant(np_values) + actual_tf_outputs = nn_impl.swish(tf_values) + expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values) + + actual_outputs, expected_outputs = self.evaluate( + [actual_tf_outputs, expected_tf_outputs]) + + self.assertAllClose(actual_outputs, expected_outputs) + + @test_util.run_deprecated_v1 + def testGradients(self): + shape = [5, 3, 4] + sigma = 5 + input_values = np.random.randn(*shape) * sigma + x_tf = constant_op.constant(input_values) + y_tf = nn_impl.swish(x_tf) + with self.cached_session(): + err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape) + self.assertLess(err, 1e-4) + + +class MomentsTest(test_lib.TestCase): + + def doOutputTest(self, + input_shape, + moments_axes, + tol=1e-4, + check_gradients=False): + for mu in [0.0, 1.0, 1e3]: + for sigma in [1.0, 0.1]: + for keep_dims in [True, False]: + input_values = np.random.rand(*input_shape) * sigma + mu + expected_mean = np.mean( + input_values, axis=moments_axes, keepdims=keep_dims) + expected_var = np.var( + input_values, axis=moments_axes, keepdims=keep_dims) + with ops.Graph().as_default() as g: + with self.session(graph=g) as sess: + inputs = constant_op.constant( + input_values, shape=input_shape, dtype=dtypes.float32) + mean, variance = nn_impl.moments_v2( + inputs, moments_axes, keepdims=keep_dims) + + if check_gradients: + err = gradient_checker.compute_gradient_error( + inputs, input_shape, mean, mean.shape.as_list()) + self.assertLess(err, 1e-3) + err = gradient_checker.compute_gradient_error( + inputs, input_shape, variance, variance.shape.as_list()) + self.assertLess(err, 1e-3) + + # Evaluate. 
+ [mean, variance] = self.evaluate([mean, variance]) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(mean).any()) + self.assertFalse(np.isnan(variance).any()) + self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol) + self.assertAllClose(variance, expected_var, rtol=tol, atol=tol) + + def testOutputAndGradient2DInput0(self): + self.doOutputTest((10, 10), (0,), check_gradients=True) + + def testOutputAndGradient2DInput01(self): + self.doOutputTest((10, 10), (0, 1), check_gradients=True) + + def testOutput2DInput0(self): + self.doOutputTest((10, 300), (0,)) + + def testOutput2DInput1(self): + self.doOutputTest((10, 300), (1,)) + + def testOutput2DInput01(self): + self.doOutputTest((10, 300), (0, 1)) + + def testOutput4DInput0(self): + self.doOutputTest((10, 10, 10, 30), (0,)) + + def testOutput4DInput1(self): + self.doOutputTest((10, 10, 10, 30), (1,)) + + def testOutput4DInput3(self): + self.doOutputTest((10, 10, 10, 30), (3,)) + + def testOutput4DInput012(self): + self.doOutputTest((10, 10, 10, 30), (0, 1, 2)) + + def testOutput4DInput123(self): + self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) + + +class DataFormatDimMapTest(test_lib.TestCase): + + def _test(self, x_val, y_val_expected): + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x) + + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def test(self): + self._test(0, 0) + self._test(1, 2) + self._test(2, 3) + self._test(3, 1) + self._test(-1, 1) + self._test(-2, 3) + self._test(-3, 2) + self._test(-4, 0) + self._test([1, 3], [2, 1]) + self._test([1, 3, -2], [2, 1, 3]) + self._test([1, -3, -2], [2, 2, 3]) + self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) + + def testNHWCtoNCHW(self): + x_val = [1, -3, -2] + y_val_expected = [2, 2, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoHWNC(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testNHWCtoWHCN(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + def testArbitraryASCII(self): + x_val = [-4, -3, -2, -1, 0, 1, 2, 3] + y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, y_val_expected) + + +class DataFormatVectorPermuteTest(test_lib.TestCase): + + def testNHWCToNCHW(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x_val = [7, 4, 
9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x_val = [[7, 4], [9, 3], [4, 5], [5, 1]] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test_lib.main() diff --git a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs index 315313c6..ca2665ff 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs @@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test var a_2 = constant_op.constant(3.0); var a_3 = constant_op.constant(4.0); var a_4 = constant_op.constant(5.0); - Operation b_1 = null, b_2 = null; + Tensor b_1 = null, b_2 = null; with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => { b_1 = constant_op.constant(6.0); @@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test }); }); }); + var z=tf.add(a_1, tf.multiply(b_2, b_1)); + with(g.control_dependencies(new[] {z}), ctrl => + { + var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); + }); + //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); }
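Illustrative sketch (not part of the patch): the ControlDependenciesTest hunk above exercises the pattern its own assertions describe, namely that an op created inside a g.control_dependencies scope picks up the listed tensors' producing ops as control_inputs when it has no data dependency on them. A minimal reproduction, assuming the same helpers already visible in ControlDependenciesTest.cs (the `with` wrapper, `constant_op`, `assertItemsEqual`, and a graph `g`), might look like this:

var a = constant_op.constant(1.0);
var b = constant_op.constant(2.0);
Tensor c = null;
with(g.control_dependencies(new[] { a, b }), ctrl =>
{
    // A constant has no data inputs, so it should receive a.op and b.op
    // purely as control inputs and therefore run only after both of them.
    c = constant_op.constant(3.0);
});
// Mirrors the assertion style used in the test above.
assertItemsEqual(c.op.control_inputs, new[] { a.op, b.op });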