diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
index bc80ff72..efa79e00 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
@@ -101,8 +101,8 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("call channels_first");
}
else
- {
- outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
+ {
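+                        // nn_ops.bias_add now takes a Tensor instead of a RefVariable (see nn_ops.cs below),
+                        // so the bias variable is converted explicitly before the call.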
+ outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC");
}
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index e54caf66..a5a2815c 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -30,7 +30,7 @@ namespace Tensorflow
///
///
public static Tensor bias_add(Tensor value,
- RefVariable bias,
+ Tensor bias,
string data_format = null,
string name = null)
{
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
index 0b5f0879..ef39d41e 100644
--- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
+++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
@@ -38,46 +38,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
});
}
- [Ignore("Todo")]
- [TestMethod]
- public void testCondMissingArg1()
- {
- // def testCondMissingArg1(self):
- // x = constant_op.constant(1)
- // with self.assertRaises(TypeError):
- // control_flow_ops.cond(True, false_fn=lambda: x)
-
- }
-
- [Ignore("Todo")]
- [TestMethod]
- public void testCondMissingArg2()
- {
- // def testCondMissingArg2(self):
- // x = constant_op.constant(1)
- // with self.assertRaises(TypeError):
- // control_flow_ops.cond(True, lambda: x)
- }
-
- [Ignore("Todo")]
- [TestMethod]
- public void testCondDuplicateArg1()
- {
- // def testCondDuplicateArg1(self):
- // x = constant_op.constant(1)
- // with self.assertRaises(TypeError):
- // control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
- }
-
- [Ignore("Todo")]
- [TestMethod]
- public void testCondDuplicateArg2()
- {
- // def testCondDuplicateArg2(self):
- // x = constant_op.constant(1)
- // with self.assertRaises(TypeError):
- // control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
- }
+        // NOTE: all other Python test cases of this class are either not needed due to strong typing or test a deprecated API.
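+        // For illustration only (hypothetical call shape, not part of this PR): a call such as
+        //     control_flow_ops.cond(pred, false_fn: () => x)   // true_fn missing
+        // does not compile in C#, so the removed Python TypeError tests have no C# equivalent.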
}
}
diff --git a/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
new file mode 100644
index 00000000..2d308042
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
@@ -0,0 +1,507 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest.gradients_test
+{
+ [TestClass]
+ public class GradientsTest : PythonTest
+ {
+
+ //[Ignore("TODO")]
+ [TestMethod]
+ public void testGradients()
+ {
+ with(tf.Graph().as_default(), g =>
+ {
+                var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in");
+                var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w");
+                var b = tf.constant(1.0, shape: new[] { 10 }, name: "b");
+                var xw = math_ops.matmul(inp, w, name: "xw");
+                var h = nn_ops.bias_add(xw, b, name: "h");
+                var w_grad = gradients_impl.gradients(new[] { h }, new[] { w })[0];
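+                // For h = matmul(inp, w) + b, d(h)/d(w) is matmul(inp, grad, transpose_a: true),
+                // which is why the gradient op below is expected to be a MatMul with
+                // transpose_a == true and transpose_b == false.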
+ self.assertEquals("MatMul", w_grad.op.type);
+ // TODO: Operation._original_op
+ //self.assertEquals(w_grad.op._original_op, xw.op);
+ self.assertTrue((bool)w_grad.op.get_attr("transpose_a"));
+ self.assertFalse((bool)w_grad.op.get_attr("transpose_b"));
+ });
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testUnusedOutput()
+ {
+ //def testUnusedOutput(self):
+ // with ops.Graph().as_default():
+ // w = constant(1.0, shape=[2, 2])
+ // x = constant(1.0, shape=[2, 2])
+ // wx = math_ops.matmul(w, x)
+ // split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
+ // c = math_ops.reduce_sum(split_wx[1])
+ // gw = gradients.gradients(c, [w])[0]
+ // self.assertEquals("MatMul", gw.op.type)
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testColocateGradients()
+ {
+
+ //def testColocateGradients(self):
+ // with ops.Graph().as_default() as g:
+ // w = constant(1.0, shape=[1, 1])
+ // x = constant(1.0, shape=[1, 2])
+ // with g.device("/device:GPU:0"):
+ // wx = math_ops.matmul(w, x)
+ // gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
+ // self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testColocateGradientsWithAggregation()
+ {
+ //def testColocateGradientsWithAggregation(self):
+ // with ops.Graph().as_default() as g:
+ // with g.device("/device:GPU:1"):
+ // w = constant(1.0, shape=[1, 1])
+ // x = constant(1.0, shape=[1, 2])
+ // y = constant(1.0, shape=[1, 2])
+ // wx = math_ops.matmul(w, x)
+ // wy = math_ops.matmul(w, y)
+ // with g.device("/device:GPU:0"):
+ // z = wx + wy
+
+ // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ // self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
+
+ // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ // self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testColocateGradientsWithAggregationInMultipleDevices()
+ {
+ //def testColocateGradientsWithAggregationInMultipleDevices(self):
+ // with ops.Graph().as_default() as g:
+ // with g.device("/device:GPU:1"):
+ // w = constant(1.0, shape=[1, 1])
+ // x = constant(1.0, shape=[1, 2])
+ // y = constant(1.0, shape=[1, 2])
+ // with g.device("/task:1"):
+ // wx = math_ops.matmul(w, x)
+ // with g.device("/task:2"):
+ // wy = math_ops.matmul(w, y)
+ // with g.device("/device:GPU:0"):
+ // z = wx + wy
+
+ // gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ // self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
+
+ // gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ // self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())
+ }
+
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testColocateGradientsWithGateGradients()
+ {
+
+ //def testColocateGradientsWithGateGradients(self):
+ // if not test_util.is_gpu_available():
+ // self.skipTest("No GPU available")
+ // with ops.Graph().as_default() as g:
+ // with g.device("/device:CPU:0"):
+ // x = constant(1.0, shape=[1, 1])
+ // y = constant(1.0, shape=[1, 1])
+ // s = x + y
+ // with g.device("/device:GPU:0"):
+ // z = math_ops.reduce_sum(s)
+
+ // gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
+ // gate_gradients=True)[0]
+ // with session.Session():
+ // # Make sure the placer doesn't complain.
+ // self.evaluate(gz_x)
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testBoundaryStop()
+ {
+ //def testBoundaryStop(self):
+ // # Test that we don't differentiate 'x'. The gradient function for 'x' is
+ // # set explicitly to None so we will get an exception if the gradient code
+ // # tries to differentiate 'x'.
+ // with ops.Graph().as_default():
+ // c = constant(1.0)
+ // x = array_ops.identity(c)
+ // y = x + 1.0
+ // z = y + 1
+ // grads = gradients.gradients(z, [x])
+ // self.assertTrue(all(x is not None for x in grads))
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testBoundaryContinue()
+ {
+ //@test_util.run_v1_only("b/120545219")
+ //def testBoundaryContinue(self):
+ // # Test that we differentiate both 'x' and 'y' correctly when x is a
+ // # predecessor of y.
+ // with self.cached_session():
+ // x = constant(1.0)
+ // y = x * 2.0
+ // z = y * 3.0
+ // grads = gradients.gradients(z, [x, y])
+ // self.assertTrue(all(x is not None for x in grads))
+ // self.assertEqual(6.0, grads[0].eval())
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testAggregationMethodAccumulateN()
+ {
+
+ //@test_util.run_v1_only("b/120545219")
+ //def testAggregationMethodAccumulateN(self):
+ // with self.cached_session():
+ // x = constant(1.0)
+ // y = x * 2.0
+ // z = y + y + y + y + y + y + y + y + y + y
+ // grads = gradients.gradients(
+ // z, [x, y],
+ // aggregation_method=gradients.AggregationMethod.
+ // EXPERIMENTAL_ACCUMULATE_N)
+ // self.assertTrue(all(x is not None for x in grads))
+ // self.assertEqual(20.0, grads[0].eval())
+ // self.assertEqual(10.0, grads[1].eval())
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testAggregationMethodAddN()
+ {
+ //@test_util.run_v1_only("b/120545219")
+ //def testAggregationMethodAddN(self):
+ // with self.cached_session():
+ // x = constant(1.0)
+ // y = x * 2.0
+ // z = y + y + y + y + y + y + y + y + y + y
+ // grads = gradients.gradients(
+ // z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
+ // self.assertTrue(all(x is not None for x in grads))
+ // self.assertEqual(20.0, grads[0].eval())
+ // self.assertEqual(10.0, grads[1].eval())
+
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testAggregationMethodTree()
+ {
+ //@test_util.run_v1_only("b/120545219")
+ //def testAggregationMethodTree(self):
+ // with self.cached_session():
+ // x = constant(1.0)
+ // y = x * 2.0
+ // z = y + y + y + y + y + y + y + y + y + y
+ // grads = gradients.gradients(
+ // z, [x, y],
+ // aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
+ // self.assertTrue(all(x is not None for x in grads))
+ // self.assertEqual(20.0, grads[0].eval())
+ // self.assertEqual(10.0, grads[1].eval())
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testNoGradientForStringOutputs()
+ {
+
+ //def testNoGradientForStringOutputs(self):
+ // with ops.Graph().as_default():
+
+ // def _TestOpGrad(_, float_grad, string_grad):
+ // """Gradient function for TestStringOutput."""
+ // self.assertEquals(float_grad.dtype, dtypes.float32)
+ // self.assertFalse(string_grad)
+ // return float_grad
+
+ // ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
+
+ // c = constant(1.0)
+ // x, _ = test_ops.test_string_output(c)
+ // z = x * 2.0
+ // w = z * 3.0
+ // grads = gradients.gradients(z, [c])
+ // self.assertTrue(isinstance(grads[0], ops.Tensor))
+ // grads = gradients.gradients(w, [c])
+ // self.assertTrue(isinstance(grads[0], ops.Tensor))
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testSingletonIndexedSlices()
+ {
+
+ //def testSingletonIndexedSlices(self):
+ // with ops.Graph().as_default():
+ // x = array_ops.placeholder(dtypes.float32)
+ // y = array_ops.identity(x)
+ // dy = ops.IndexedSlices(
+ // array_ops.placeholder(dtypes.float32),
+ // array_ops.placeholder(dtypes.int32))
+ // dx, = gradients.gradients(y, x, grad_ys=dy)
+ // # The IndexedSlices gradient of tf.identity is the identity map.
+ // with self.cached_session() as sess:
+ // vdx, vdy = sess.run(
+ // [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
+ // self.assertEqual(vdx, vdy)
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testNonDifferentiableSwitchInWhileLoop()
+ {
+
+
+ //@test_util.run_v1_only("b/120545219")
+ //def testNonDifferentiableSwitchInWhileLoop(self):
+ // with ops.Graph().as_default():
+ // v = array_ops.placeholder(dtypes.float32, [])
+
+ // def _Step(i, a, ta):
+ // a += math_ops.cast(v, dtypes.int32)
+ // return (i + 1, a, ta.write(i, a))
+
+ // n = 4
+ // i, _, ta = control_flow_ops.while_loop(
+ // lambda i, *_: i < n,
+ // _Step, [0, 0, tensor_array_ops.TensorArray(
+ // dtypes.int32, size=n)])
+ // target = ta.read(i - 1)
+ // grad, = gradients.gradients(target, v)
+ // self.assertIsNone(grad)
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testVariableReadValueGradient()
+ {
+
+ //def testVariableReadValueGradient(self):
+ // with ops.Graph().as_default():
+ // init = constant_op.constant(100.0)
+ // var = variables.Variable(init)
+ // gradient = gradients.gradients(var.read_value(), var)
+ // self.assertIsNotNone(gradient)
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testVariableAsGraphElementGradient()
+ {
+ //def testVariableAsGraphElementGradient(self):
+ // with ops.Graph().as_default() as graph:
+ // init = constant_op.constant(100.0)
+ // var = variables.Variable(init)
+ // gradient = gradients.gradients(graph.as_graph_element(var), var)
+ // self.assertIsNotNone(gradient)
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testVariableRefGradient()
+ {
+
+ //@test_util.run_v1_only("b/120545219")
+ //def testVariableRefGradient(self):
+ // with ops.Graph().as_default():
+ // init = constant_op.constant(100.0)
+ // var = variables.VariableV1(init)
+ // gradient = gradients.gradients(var._ref(), var)
+ // self.assertIsNotNone(gradient)
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testDependentYs()
+ {
+ //@test_util.run_v1_only("b/120545219")
+ //def testDependentYs(self):
+ // with self.cached_session():
+ // x = constant_op.constant(3.0)
+ // y = math_ops.square(x)
+ // y1 = math_ops.square(y)
+ // y2 = math_ops.square(y1)
+ // g = gradients.gradients([y, y2], x)
+ // self.assertAllClose(17502.0, g[0].eval())
+ // g = gradients.gradients(y + y2, x)
+ // self.assertAllClose(17502.0, g[0].eval())
+ // z = array_ops.identity(y)
+ // z2 = array_ops.identity(y2)
+ // g = gradients.gradients([z, z2], x)
+ // self.assertAllClose(17502.0, g[0].eval())
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testPartialDerivatives()
+ {
+
+ //@test_util.run_v1_only("b/120545219")
+ //def testPartialDerivatives(self):
+ // with self.cached_session():
+ // x = constant_op.constant(1.)
+ // y = 2 * x
+ // z = x + y
+ // totalg = gradients.gradients(z, [x, y])
+ // self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
+ // partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
+ // self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testStopGradients()
+ {
+
+
+ //@test_util.run_v1_only("b/120545219")
+ //def testStopGradients(self):
+ // def _MakeGraph(rng, stop_gradients=()):
+ // def _FunctionOf(xs, k=3):
+ // return ops.convert_to_tensor(
+ // sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
+ // + rng.rand(k, k))
+
+ // a = _FunctionOf([])
+ // if "a" in stop_gradients: a = array_ops.stop_gradient(a)
+ // b = _FunctionOf([a])
+ // if "b" in stop_gradients: b = array_ops.stop_gradient(b)
+ // c = _FunctionOf([a, b])
+ // if "c" in stop_gradients: c = array_ops.stop_gradient(c)
+ // d = _FunctionOf([b, c])
+ // if "d" in stop_gradients: d = array_ops.stop_gradient(d)
+ // return dict(a=a, b=b, c=c, d=d)
+
+ // def _Gradients(ys, xs, **kwargs):
+ // dydxs = gradients.gradients(ys, xs, **kwargs)
+ // dydxs = [0. * x if dydx is None else dydx
+ // for x, dydx in zip(xs, dydxs)]
+ // return dydxs
+ // seed = np.random.randint(1000)
+ // cases = []
+ // subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
+ // graph = _MakeGraph(np.random.RandomState(seed))
+ // for constants in subsets:
+ // graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
+ // for variables_ in subsets:
+ // # compute the gradient when stopped using tf.stop_gradients
+ // grad1 = _Gradients([graph_with_stops["d"]],
+ // [graph_with_stops[v] for v in variables_])
+ // # compute the gradient when stopped using the stop_gradients kwarg
+ // grad2 = _Gradients([graph["d"]],
+ // [graph[v] for v in variables_],
+ // stop_gradients=[graph[v] for v in constants])
+ // cases.append(dict(grad1=grad1, grad2=grad2,
+ // constants=constants, variables=variables_))
+
+ // # evaluate all tensors in one call to session.run for speed
+ // with self.cached_session() as sess:
+ // results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
+
+ // for (npgrad1, npgrad2), case in zip(results, cases):
+ // for a, b in zip(npgrad1, npgrad2):
+ // np.testing.assert_allclose(a, b)
+
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testUnconnectedGradientsNoneUnconnectedGradients()
+ {
+
+
+ //def testUnconnectedGradientsNoneUnconnectedGradients(self):
+ // with ops.Graph().as_default():
+ // x = constant(1.0, shape=[2, 2])
+ // y = constant(3.0, shape=[3, 1])
+ // grad = gradients.gradients(
+ // [y], [x], unconnected_gradients="none")
+ // self.assertIsNone(grad[0])
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testUnconnectedGradientsZerosUnconnectedGradients()
+ {
+
+
+ //def testUnconnectedGradientsZerosUnconnectedGradients(self):
+ // with ops.Graph().as_default():
+ // x = constant(1.0, shape=[2, 2])
+ // y = constant(3.0, shape=[3, 1])
+ // grads = gradients.gradients(
+ // [y], [x], unconnected_gradients="zero")
+ // with self.cached_session() as sess:
+ // self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testUnconnectedGradientsZeroConnectedGradients()
+ {
+
+
+
+ //def testUnconnectedGradientsZeroConnectedGradients(self):
+ // with ops.Graph().as_default():
+ // x = constant(1.0)
+ // y = x * 3.0
+ // grad = gradients.gradients(
+ // [y], [x], unconnected_gradients="zero")
+ // with self.cached_session() as sess:
+ // self.assertEquals(3.0, self.evaluate(grad)[0])
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testUnknownUnconnectedGradientsValueGiven()
+ {
+ //def testUnknownUnconnectedGradientsValueGiven(self):
+ // with ops.Graph().as_default():
+ // x = constant(1.0)
+ // y = constant(1.0)
+ // with self.assertRaisesRegexp(
+ // ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+ // gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py b/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py
new file mode 100644
index 00000000..c53afef6
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py
@@ -0,0 +1,1104 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import warnings
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function as framework_function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.layers import core as core_layers
+from tensorflow.python.ops import array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
+from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import
+from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
+from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_grad # pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.nn_ops import bias_add
+from tensorflow.python.platform import googletest
+
+
+class GradientsTest(test_util.TensorFlowTestCase):
+
+ def testGradients(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[32, 100], name="in")
+ w = constant(1.0, shape=[100, 10], name="w")
+ b = constant(1.0, shape=[10], name="b")
+ xw = math_ops.matmul(inp, w, name="xw")
+ h = bias_add(xw, b, name="h")
+ w_grad = gradients.gradients(h, w)[0]
+ self.assertEquals("MatMul", w_grad.op.type)
+ self.assertEquals(w_grad.op._original_op, xw.op)
+ self.assertTrue(w_grad.op.get_attr("transpose_a"))
+ self.assertFalse(w_grad.op.get_attr("transpose_b"))
+
+ def testUnusedOutput(self):
+ with ops.Graph().as_default():
+ w = constant(1.0, shape=[2, 2])
+ x = constant(1.0, shape=[2, 2])
+ wx = math_ops.matmul(w, x)
+ split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
+ c = math_ops.reduce_sum(split_wx[1])
+ gw = gradients.gradients(c, [w])[0]
+ self.assertEquals("MatMul", gw.op.type)
+
+ def testColocateGradients(self):
+ with ops.Graph().as_default() as g:
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ with g.device("/device:GPU:0"):
+ wx = math_ops.matmul(w, x)
+ gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
+
+ def testColocateGradientsWithAggregation(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/device:GPU:1"):
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ y = constant(1.0, shape=[1, 2])
+ wx = math_ops.matmul(w, x)
+ wy = math_ops.matmul(w, y)
+ with g.device("/device:GPU:0"):
+ z = wx + wy
+
+ gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
+
+ gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())
+
+ def testColocateGradientsWithAggregationInMultipleDevices(self):
+ with ops.Graph().as_default() as g:
+ with g.device("/device:GPU:1"):
+ w = constant(1.0, shape=[1, 1])
+ x = constant(1.0, shape=[1, 2])
+ y = constant(1.0, shape=[1, 2])
+ with g.device("/task:1"):
+ wx = math_ops.matmul(w, x)
+ with g.device("/task:2"):
+ wy = math_ops.matmul(w, y)
+ with g.device("/device:GPU:0"):
+ z = wx + wy
+
+ gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
+ self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
+
+ gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
+ self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())
+
+ def testColocateGradientsWithGateGradients(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+ with ops.Graph().as_default() as g:
+ with g.device("/device:CPU:0"):
+ x = constant(1.0, shape=[1, 1])
+ y = constant(1.0, shape=[1, 1])
+ s = x + y
+ with g.device("/device:GPU:0"):
+ z = math_ops.reduce_sum(s)
+
+ gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
+ gate_gradients=True)[0]
+ with session.Session():
+ # Make sure the placer doesn't complain.
+ self.evaluate(gz_x)
+
+ def testBoundaryStop(self):
+ # Test that we don't differentiate 'x'. The gradient function for 'x' is
+ # set explicitly to None so we will get an exception if the gradient code
+ # tries to differentiate 'x'.
+ with ops.Graph().as_default():
+ c = constant(1.0)
+ x = array_ops.identity(c)
+ y = x + 1.0
+ z = y + 1
+ grads = gradients.gradients(z, [x])
+ self.assertTrue(all(x is not None for x in grads))
+
+ @test_util.run_v1_only("b/120545219")
+ def testBoundaryContinue(self):
+ # Test that we differentiate both 'x' and 'y' correctly when x is a
+ # predecessor of y.
+ with self.cached_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y * 3.0
+ grads = gradients.gradients(z, [x, y])
+ self.assertTrue(all(x is not None for x in grads))
+ self.assertEqual(6.0, grads[0].eval())
+
+ @test_util.run_v1_only("b/120545219")
+ def testAggregationMethodAccumulateN(self):
+ with self.cached_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z, [x, y],
+ aggregation_method=gradients.AggregationMethod.
+ EXPERIMENTAL_ACCUMULATE_N)
+ self.assertTrue(all(x is not None for x in grads))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ @test_util.run_v1_only("b/120545219")
+ def testAggregationMethodAddN(self):
+ with self.cached_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
+ self.assertTrue(all(x is not None for x in grads))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ @test_util.run_v1_only("b/120545219")
+ def testAggregationMethodTree(self):
+ with self.cached_session():
+ x = constant(1.0)
+ y = x * 2.0
+ z = y + y + y + y + y + y + y + y + y + y
+ grads = gradients.gradients(
+ z, [x, y],
+ aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
+ self.assertTrue(all(x is not None for x in grads))
+ self.assertEqual(20.0, grads[0].eval())
+ self.assertEqual(10.0, grads[1].eval())
+
+ def testNoGradientForStringOutputs(self):
+ with ops.Graph().as_default():
+
+ def _TestOpGrad(_, float_grad, string_grad):
+ """Gradient function for TestStringOutput."""
+ self.assertEquals(float_grad.dtype, dtypes.float32)
+ self.assertFalse(string_grad)
+ return float_grad
+
+ ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
+
+ c = constant(1.0)
+ x, _ = test_ops.test_string_output(c)
+ z = x * 2.0
+ w = z * 3.0
+ grads = gradients.gradients(z, [c])
+ self.assertTrue(isinstance(grads[0], ops.Tensor))
+ grads = gradients.gradients(w, [c])
+ self.assertTrue(isinstance(grads[0], ops.Tensor))
+
+ def testSingletonIndexedSlices(self):
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.identity(x)
+ dy = ops.IndexedSlices(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32))
+ dx, = gradients.gradients(y, x, grad_ys=dy)
+ # The IndexedSlices gradient of tf.identity is the identity map.
+ with self.cached_session() as sess:
+ vdx, vdy = sess.run(
+ [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
+ self.assertEqual(vdx, vdy)
+
+ @test_util.run_v1_only("b/120545219")
+ def testNonDifferentiableSwitchInWhileLoop(self):
+ with ops.Graph().as_default():
+ v = array_ops.placeholder(dtypes.float32, [])
+
+ def _Step(i, a, ta):
+ a += math_ops.cast(v, dtypes.int32)
+ return (i + 1, a, ta.write(i, a))
+
+ n = 4
+ i, _, ta = control_flow_ops.while_loop(
+ lambda i, *_: i < n,
+ _Step, [0, 0, tensor_array_ops.TensorArray(
+ dtypes.int32, size=n)])
+ target = ta.read(i - 1)
+ grad, = gradients.gradients(target, v)
+ self.assertIsNone(grad)
+
+ def testVariableReadValueGradient(self):
+ with ops.Graph().as_default():
+ init = constant_op.constant(100.0)
+ var = variables.Variable(init)
+ gradient = gradients.gradients(var.read_value(), var)
+ self.assertIsNotNone(gradient)
+
+ def testVariableAsGraphElementGradient(self):
+ with ops.Graph().as_default() as graph:
+ init = constant_op.constant(100.0)
+ var = variables.Variable(init)
+ gradient = gradients.gradients(graph.as_graph_element(var), var)
+ self.assertIsNotNone(gradient)
+
+ @test_util.run_v1_only("b/120545219")
+ def testVariableRefGradient(self):
+ with ops.Graph().as_default():
+ init = constant_op.constant(100.0)
+ var = variables.VariableV1(init)
+ gradient = gradients.gradients(var._ref(), var)
+ self.assertIsNotNone(gradient)
+
+ @test_util.run_v1_only("b/120545219")
+ def testDependentYs(self):
+ with self.cached_session():
+ x = constant_op.constant(3.0)
+ y = math_ops.square(x)
+ y1 = math_ops.square(y)
+ y2 = math_ops.square(y1)
+ g = gradients.gradients([y, y2], x)
+ self.assertAllClose(17502.0, g[0].eval())
+ g = gradients.gradients(y + y2, x)
+ self.assertAllClose(17502.0, g[0].eval())
+ z = array_ops.identity(y)
+ z2 = array_ops.identity(y2)
+ g = gradients.gradients([z, z2], x)
+ self.assertAllClose(17502.0, g[0].eval())
+
+ @test_util.run_v1_only("b/120545219")
+ def testPartialDerivatives(self):
+ with self.cached_session():
+ x = constant_op.constant(1.)
+ y = 2 * x
+ z = x + y
+ totalg = gradients.gradients(z, [x, y])
+ self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
+ partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
+ self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
+
+ @test_util.run_v1_only("b/120545219")
+ def testStopGradients(self):
+ def _MakeGraph(rng, stop_gradients=()):
+ def _FunctionOf(xs, k=3):
+ return ops.convert_to_tensor(
+ sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
+ + rng.rand(k, k))
+
+ a = _FunctionOf([])
+ if "a" in stop_gradients: a = array_ops.stop_gradient(a)
+ b = _FunctionOf([a])
+ if "b" in stop_gradients: b = array_ops.stop_gradient(b)
+ c = _FunctionOf([a, b])
+ if "c" in stop_gradients: c = array_ops.stop_gradient(c)
+ d = _FunctionOf([b, c])
+ if "d" in stop_gradients: d = array_ops.stop_gradient(d)
+ return dict(a=a, b=b, c=c, d=d)
+
+ def _Gradients(ys, xs, **kwargs):
+ dydxs = gradients.gradients(ys, xs, **kwargs)
+ dydxs = [0. * x if dydx is None else dydx
+ for x, dydx in zip(xs, dydxs)]
+ return dydxs
+
+ seed = np.random.randint(1000)
+ cases = []
+ subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
+ graph = _MakeGraph(np.random.RandomState(seed))
+ for constants in subsets:
+ graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
+ for variables_ in subsets:
+ # compute the gradient when stopped using tf.stop_gradients
+ grad1 = _Gradients([graph_with_stops["d"]],
+ [graph_with_stops[v] for v in variables_])
+ # compute the gradient when stopped using the stop_gradients kwarg
+ grad2 = _Gradients([graph["d"]],
+ [graph[v] for v in variables_],
+ stop_gradients=[graph[v] for v in constants])
+ cases.append(dict(grad1=grad1, grad2=grad2,
+ constants=constants, variables=variables_))
+
+ # evaluate all tensors in one call to session.run for speed
+ with self.cached_session() as sess:
+ results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
+
+ for (npgrad1, npgrad2), case in zip(results, cases):
+ for a, b in zip(npgrad1, npgrad2):
+ np.testing.assert_allclose(a, b)
+
+ def testUnconnectedGradientsNoneUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="none")
+ self.assertIsNone(grad[0])
+
+ def testUnconnectedGradientsZerosUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grads = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
+
+ def testUnconnectedGradientsZeroConnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = x * 3.0
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertEquals(3.0, self.evaluate(grad)[0])
+
+ def testUnknownUnconnectedGradientsValueGiven(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = constant(1.0)
+ with self.assertRaisesRegexp(
+ ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+ gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
+
+class FunctionGradientsTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def XSquarePlusB(cls, x, b):
+ return x * x + b
+
+ @classmethod
+ def XSquarePlusBGradient(cls, x, b, g):
+ # Perturb gradients (multiply by 2), so we can test that this was called.
+ g *= 2.0
+ return g * 2.0 * x, g
+
+ @classmethod
+ def _PythonGradient(cls, op, grad):
+ # Perturb gradients (multiply by 3), so we can test that this was called.
+ grad *= 3.0
+ return grad * op.inputs[0] * 2.0, grad
+
+ @classmethod
+ def _GetFunc(cls, **kwargs):
+ return framework_function.Defun(dtypes.float32, dtypes.float32, **
+ kwargs)(cls.XSquarePlusB)
+
+ def _GetFuncGradients(self, f, x_value, b_value):
+ x = constant_op.constant(x_value, name="x")
+ b = constant_op.constant(b_value, name="b")
+
+ y = f(x, b)
+ grads = gradients.gradients(y, [x, b])
+ with self.cached_session() as sess:
+ return sess.run(grads)
+
+ def testFunctionGradientsBasic(self):
+ g = ops.Graph()
+ with g.as_default():
+ f = self._GetFunc()
+ # Get gradients (should add SymbolicGradient node for function).
+ grads = self._GetFuncGradients(f, [2.0], [1.0])
+ self.assertAllEqual([4.0], grads[0])
+ self.assertAllEqual([1.0], grads[1])
+
+ def testFunctionGradientsComposition(self):
+ with ops.Graph().as_default():
+ f = self._GetFunc()
+ x = constant_op.constant([2.0], name="x")
+ b1 = constant_op.constant([1.0], name="b1")
+ b2 = constant_op.constant([1.0], name="b2")
+
+ y = f(f(x, b1), b2)
+ # Build gradient graph (should add SymbolicGradient node for function).
+ grads = gradients.gradients(y, [x, b1])
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([40.0], self.evaluate(grads)[0])
+ self.assertAllEqual([10.0], self.evaluate(grads)[1])
+
+ def testFunctionGradientsWithGradFunc(self):
+ g = ops.Graph()
+ with g.as_default():
+ grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+ dtypes.float32)(
+ self.XSquarePlusBGradient)
+ f = self._GetFunc(grad_func=grad_func)
+ # Get gradients (should add SymbolicGradient node for function, which
+ # uses the grad_func above, which multiplies all gradients by 2).
+ grads = self._GetFuncGradients(f, [2.0], [1.0])
+ self.assertAllEqual([4.0 * 2], grads[0])
+ self.assertAllEqual([1.0 * 2], grads[1])
+
+ def testFunctionGradientWithRegistration(self):
+ g = ops.Graph()
+ with g.as_default():
+ f = self._GetFunc(python_grad_func=self._PythonGradient)
+ # Get gradients, using the python gradient function. It multiplies the
+ # gradients by 3.
+ grads = self._GetFuncGradients(f, [2.0], [1.0])
+ self.assertAllEqual([4.0 * 3], grads[0])
+ self.assertAllEqual([1.0 * 3], grads[1])
+
+ def testFunctionGradientWithGradFuncAndRegistration(self):
+ g = ops.Graph()
+ with g.as_default():
+ grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+ dtypes.float32)(
+ self.XSquarePlusBGradient)
+ with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
+ f = self._GetFunc(
+ grad_func=grad_func, python_grad_func=self._PythonGradient)
+ f.add_to_graph(ops.Graph())
+
+ def testGradientWrtCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.defun()
+ def Foo():
+ y = math_ops.multiply(x, 2.0, name="y")
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.cached_session() as sess:
+ self.assertEqual(self.evaluate(f), 2.0)
+
+ def testGradientOfCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @framework_function.Defun()
+ def Foo():
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.cached_session() as sess:
+ self.assertEqual(self.evaluate(f), 2.0)
+
+ def testCapturedResourceVariable(self):
+ with ops.Graph().as_default():
+ var = resource_variable_ops.ResourceVariable(1.0, name="var")
+
+ @function.defun()
+ def Foo():
+ y = math_ops.multiply(var, 2.0, name="y")
+ g = gradients_impl.gradients(y, var)
+ return g[0]
+
+ f = Foo()
+ with self.cached_session() as sess:
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(f), 2.0)
+
+ def testCapturedNested(self):
+ with ops.Graph().as_default():
+ x1 = constant_op.constant(1.0, name="x1")
+ x2 = constant_op.constant(2.0, name="x2")
+ x3 = math_ops.multiply(x1, x2, name="x3")
+
+ @function.defun()
+ def Outer():
+ outer1 = array_ops.identity(x1, name="outer1")
+
+ @function.defun()
+ def Inner():
+ inner1 = array_ops.identity(outer1, name="inner1")
+ inner2 = array_ops.identity(x2, name="inner2")
+ inner3 = array_ops.identity(x3, name="inner3")
+ return gradients_impl.gradients([inner1, inner2, inner3, x1],
+ [x1, x2])
+
+ return Inner()
+
+ x1_grad, x2_grad = Outer()
+ with self.cached_session() as sess:
+ # 1.0 + None + 2.0 + 1.0 = 4.0
+ self.assertEqual(self.evaluate(x1_grad), 4.0)
+ # None + 1.0 + 1.0 + None = 2.0
+ self.assertEqual(self.evaluate(x2_grad), 2.0)
+
+ def testCapturedFromFunction(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.defun()
+ def Outer():
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.defun()
+ def Inner():
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, y)
+ return g[0]
+
+ return Inner()
+
+ z_grad = Outer()
+ with self.cached_session() as sess:
+ self.assertEqual(self.evaluate(z_grad), 3.0)
+
+ def testCapturedEagerTensors(self):
+ # Test that we can handle captured eager tensors unrelated to the gradient
+ # computation (i.e. we need to ignore them).
+ # TODO(skyewm): make it an error if you try to take the gradient wrt a
+ # captured EagerTensor
+ with context.eager_mode():
+ c = constant_op.constant(2.0, name="c")
+
+ @function.defun
+ def Foo():
+ x = constant_op.constant(10.0, name="x")
+ y = math_ops.multiply(x, c, name="y")
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, x)
+ return g[0]
+
+ self.assertEqual(Foo().numpy(), 6.0)
+
+
+class StopGradientTest(test_util.TensorFlowTestCase):
+
+ def testStopGradient(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[100, 32], name="in")
+ out = array_ops.stop_gradient(inp)
+ igrad = gradients.gradients(out, inp)[0]
+ assert igrad is None
+
+
+class PreventGradientTest(test_util.TensorFlowTestCase):
+
+ def testPreventGradient(self):
+ with ops.Graph().as_default():
+ inp = constant(1.0, shape=[100, 32], name="in")
+ out = array_ops.prevent_gradient(inp)
+ with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
+ _ = gradients.gradients(out, inp)
+
+
+class HessianVectorProductTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessianVectorProduct(self):
+ # Manually compute the Hessian explicitly for a low-dimensional problem
+ # and check that HessianVectorProduct matches multiplication by the
+ # explicit Hessian.
+ # Specifically, the Hessian of f(x) = x^T A x is
+ # H = A + A^T.
+ # We expect HessianVectorProduct(f(x), x, v) to be H v.
+ m = 4
+ rng = np.random.RandomState([1, 2, 3])
+ mat_value = rng.randn(m, m).astype("float32")
+ v_value = rng.randn(m, 1).astype("float32")
+ x_value = rng.randn(m, 1).astype("float32")
+ hess_value = mat_value + mat_value.T
+ hess_v_value = np.dot(hess_value, v_value)
+ for use_gpu in [False, True]:
+ with self.cached_session(use_gpu=use_gpu):
+ mat = constant_op.constant(mat_value)
+ v = constant_op.constant(v_value)
+ x = constant_op.constant(x_value)
+ mat_x = math_ops.matmul(mat, x, name="Ax")
+ x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
+ hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
+ hess_v_actual = self.evaluate(hess_v)
+ self.assertAllClose(hess_v_value, hess_v_actual)
+
+
+class HessianTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessian1D(self):
+ # Manually compute the Hessian explicitly for a low-dimensional problem
+ # and check that `hessian` matches. Specifically, the Hessian of
+ # f(x) = x^T A x is H = A + A^T.
+ m = 4
+ rng = np.random.RandomState([1, 2, 3])
+ mat_value = rng.randn(m, m).astype("float32")
+ x_value = rng.randn(m).astype("float32")
+ hess_value = mat_value + mat_value.T
+ with self.session(use_gpu=True):
+ mat = constant_op.constant(mat_value)
+ x = constant_op.constant(x_value)
+ x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
+ hess = gradients.hessians(x_mat_x, x)[0]
+ hess_actual = self.evaluate(hess)
+ self.assertAllClose(hess_value, hess_actual)
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessian1D_multi(self):
+ # Test the computation of the hessian with respect to multiple tensors
+ m = 4
+ n = 3
+ rng = np.random.RandomState([1, 2, 3])
+ mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
+ x_values = [rng.randn(m).astype("float32") for _ in range(n)]
+ hess_values = [mat_value + mat_value.T for mat_value in mat_values]
+ with self.session(use_gpu=True):
+ mats = [constant_op.constant(mat_value) for mat_value in mat_values]
+ xs = [constant_op.constant(x_value) for x_value in x_values]
+ xs_mats_xs = [
+ math_ops.reduce_sum(x[:, None] * mat * x[None, :])
+ for x, mat in zip(xs, mats)
+ ]
+ hessians = gradients.hessians(xs_mats_xs, xs)
+ hessians_actual = [hess.eval() for hess in hessians]
+ for hess_value, hess_actual in zip(hess_values, hessians_actual):
+ self.assertAllClose(hess_value, hess_actual)
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessianInvalidDimension(self):
+ for shape in [(10, 10), None]:
+ with self.cached_session(use_gpu=True):
+ x = array_ops.placeholder(dtypes.float32, shape)
+ # Expect a ValueError because the dimensions are wrong
+ with self.assertRaises(ValueError):
+ gradients.hessians(x, x)
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessian2D_square_matrix(self):
+ # Manually compute the Hessian explicitly for a low-dimensional problem
+ # and check that `hessian` matches. Specifically, the Hessian of
+ # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
+ m = 3
+ rng = np.random.RandomState([1, 2, 3])
+ x_value = rng.randn(m, m).astype("float32")
+ with self.session(use_gpu=True):
+ x = constant_op.constant(x_value)
+ x_square = math_ops.reduce_sum(
+ math_ops.matmul(array_ops.transpose(x), x) * 0.5
+ )
+ hess = gradients.hessians(x_square, x)[0]
+ hess_actual = self.evaluate(hess)
+ hess_value = np.bmat([
+ [elem*np.ones((m, m)) for elem in vec]
+ for vec in np.eye(m)
+ ]).astype("float32")
+ self.assertAllEqual((m, m, m, m), hess_actual.shape)
+ self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))
+
+ @test_util.run_v1_only("b/120545219")
+ def testHessian2D_non_square_matrix(self):
+ m = 3
+ n = 4
+ rng = np.random.RandomState([1, 2, 3])
+ x_value = rng.randn(m, n).astype("float32")
+ with self.session(use_gpu=True):
+ x = constant_op.constant(x_value)
+ x_square = math_ops.reduce_sum(
+ math_ops.matmul(array_ops.transpose(x), x) * 0.5
+ )
+ hess = gradients.hessians(x_square, x)[0]
+ hess_actual = self.evaluate(hess)
+ hess_value = np.bmat([
+ [elem*np.ones((n, n)) for elem in vec]
+ for vec in np.eye(m)
+ ]).astype("float32")
+ self.assertAllEqual((m, n, m, n), hess_actual.shape)
+ self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
+
+
+class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testIndexedSlicesToTensor(self):
+ with self.cached_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.multiply(c_sparse, 1.0)
+ self.assertAllClose(np_val, self.evaluate(c_dense))
+
+ @test_util.run_v1_only("b/120545219")
+ def testIndexedSlicesToTensorList(self):
+ with self.cached_session():
+ numpy_list = []
+ dense_list = []
+ sparse_list = []
+ for _ in range(3):
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ numpy_list.append(np_val)
+ dense_list.append(c)
+ sparse_list.append(c_sparse)
+ packed_dense = array_ops.stack(dense_list)
+ packed_sparse = array_ops.stack(sparse_list)
+ self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))
+
+ @test_util.run_v1_only("b/120545219")
+ def testInt64Indices(self):
+ with self.cached_session():
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ c_sparse = ops.IndexedSlices(
+ c_sparse.values,
+ math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
+ self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
+ c_dense = math_ops.multiply(c_sparse, 1.0)
+ self.assertAllClose(np_val, self.evaluate(c_dense))
+
+ @test_util.run_v1_only("b/120545219")
+ def testWarnings(self):
+ # TODO(gunan) Reenable after this issue is fixed:
+ # https://github.com/google/protobuf/issues/2812
+ if sys.version_info >= (3, 5):
+ self.skipTest("Skipped test for Python 3.5+")
+
+ # Smaller than the threshold: no warning.
+ c_sparse = ops.IndexedSlices(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.multiply(c_sparse, 1.0)
+ self.assertEqual(0, len(w))
+
+ # Greater than or equal to the threshold: warning.
+ c_sparse = ops.IndexedSlices(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
+ # "always" filter prevents the warning from being suppressed if it was
+ # already triggered in a different test.
+ warnings.simplefilter("always")
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.multiply(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "with 100000000 elements. This may consume a large amount of memory." in
+ str(w[0].message))
+
+ # Unknown dense shape: warning.
+ c_sparse = ops.IndexedSlices(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.int32),
+ array_ops.placeholder(dtypes.int32))
+ with warnings.catch_warnings(record=True) as w:
+ math_ops.multiply(c_sparse, 1.0)
+ self.assertEqual(1, len(w))
+ self.assertTrue(
+ "of unknown shape. This may consume a large amount of memory." in
+ str(w[0].message))
+
+
+class OnlyRealGradientsTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testRealOnly(self):
+ x = constant_op.constant(7+3j, dtype=dtypes.complex64)
+ y = math_ops.square(x)
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"Gradients of complex tensors must set grad_ys "
+ r"\(y\.dtype = tf\.complex64\)"):
+ gradients.gradients(y, x)
+
+
+class ResourceCondTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testBasic(self):
+ gamma = resource_variable_ops.ResourceVariable(
+ np.random.random((3,)),
+ dtype="float32", name="gamma")
+
+ inputs = array_ops.ones(shape=(3,), dtype="float32")
+
+ def TestFn():
+ output = inputs + gamma
+ return output
+
+ training = array_ops.placeholder_with_default(True, shape=())
+ output = control_flow_ops.cond(
+ training, TestFn, lambda: inputs)
+
+ loss = output
+
+ grads = gradients.gradients(
+ loss, [gamma])
+ self.assertTrue(None not in grads)
+
+
+class CustomGradientTest(test_util.TensorFlowTestCase):
+
+ def testCustomGradientTrivial(self):
+
+ @custom_gradient.custom_gradient
+ def MyIdentity(x):
+
+ def Grad(dy):
+ return [3 * dy]
+
+ return x, Grad
+
+ with ops.Graph().as_default():
+ x = constant(3.)
+ y = MyIdentity(MyIdentity(x))
+ dy = gradients.gradients(y, x)[0]
+ with session.Session():
+ self.assertEqual(9., self.evaluate(dy))
+
+ def testCustomGradient(self):
+
+ @custom_gradient.custom_gradient
+ def MyMultiply(x1, x2):
+ result = x1 * x2
+
+ def Grad(dy):
+ # Switched the ordering here.
+ return [dy * x1, dy * x2]
+
+ return result, Grad
+
+ with ops.Graph().as_default():
+ x1 = constant(3.)
+ x2 = constant(5.)
+ y = MyMultiply(x1, x2)
+ dy = gradients.gradients(y, [x1, x2])
+ with session.Session() as sess:
+ self.assertAllEqual([3., 5.], self.evaluate(dy))
+
+ def testCustomGradientErrors(self):
+
+ @custom_gradient.custom_gradient
+ def F(x):
+
+ def Grad(_):
+ raise RuntimeError("x")
+
+ return x, Grad
+
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = F(x)
+ with self.assertRaises(RuntimeError):
+ gradients.gradients(y, x)
+
+ def testCustomGradientWithVariables(self):
+
+ @custom_gradient.custom_gradient
+ def F(x):
+ out = core_layers.dense(x, 3, use_bias=False)
+
+ def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
+ self.assertEqual(1, len(variables))
+ grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
+ return grads[0], [array_ops.ones((4, 3))]
+
+ return out, Grad
+
+ with ops.Graph().as_default():
+ x = array_ops.ones((2, 4))
+ with variable_scope.variable_scope("f", use_resource=True) as vs:
+ y = F(x)
+ all_vars = vs.global_variables()
+ assert len(all_vars) == 1
+ grads = gradients.gradients(y, [x, all_vars[0]])
+ for g in grads:
+ self.assertTrue(g is not None)
+ with session.Session() as sess:
+ self.evaluate(variables.global_variables_initializer())
+ dw = sess.run(math_ops.reduce_sum(grads[1]))
+ self.assertEqual(12., dw)
+
+ def testCustomGradientWithVariablesEager(self):
+ with context.eager_mode():
+ layer = core_layers.Dense(4, use_bias=False)
+
+ @custom_gradient.custom_gradient
+ def F(x):
+ out = layer(x)
+
+ def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
+ del out_grad
+ self.assertEqual(1, len(variables))
+ return (array_ops.ones((3, 2)),
+ [array_ops.ones((2, 4))])
+
+ return out, Grad
+
+ x = array_ops.ones((3, 2)) + 2.
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ y = F(x)
+ w, = layer.variables
+ dx, dw = tape.gradient(y, [x, w])
+ self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
+ self.assertEqual(8., math_ops.reduce_sum(dw).numpy())
+
+ @test_util.run_v1_only("b/120545219")
+ def testCustomGradientErrorsWithNonResourceVariables(self):
+
+ def F(x, use_resource=False):
+ with variable_scope.variable_scope("f", use_resource=use_resource):
+ out = core_layers.dense(x, 4, use_bias=False)
+
+ def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
+ del out_grad
+ self.assertEqual(1, len(variables))
+ return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
+
+ return out, Grad
+
+ @custom_gradient.custom_gradient
+ def FResource(x):
+ return F(x, use_resource=True)
+
+ @custom_gradient.custom_gradient
+ def FNonResource(x):
+ return F(x, use_resource=False)
+
+ x = array_ops.ones((3, 2)) + 2.
+
+ # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
+ with variable_scope.variable_scope("vs1", use_resource=True):
+ with self.assertRaisesWithPredicateMatch(TypeError,
+ "must be `ResourceVariable`s"):
+ FNonResource(x)
+
+ # Wrapping scope has use_resource=False but inner scope sets to True.
+ # Passes.
+ with variable_scope.variable_scope("vs2", use_resource=False):
+ FResource(x)
+
+ def testWithNumpyInputs(self):
+ with context.eager_mode():
+
+ @custom_gradient.custom_gradient
+ def F(x):
+ out = x
+
+ def Grad(_):
+ return (None, None)
+
+ return out, Grad
+
+ x = np.ones((3, 2), dtype=np.float32)
+ # Smoke test to ensure numpy inputs are accepted
+ F(x)
+
+ @test_util.run_v1_only("b/120545219")
+ def testRVGradientsDynamicCond(self):
+ with self.cached_session():
+ alpha = resource_variable_ops.ResourceVariable(
+ np.random.random((1,)),
+ dtype="float32")
+
+ conditional = array_ops.placeholder_with_default(True, shape=())
+ output = control_flow_ops.cond(
+ conditional, lambda: alpha * 2, lambda: alpha * 3)
+
+ g, = gradients_impl.gradients(output, alpha)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllEqual(g.eval(), [2.0])
+ self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
+
+
+class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
+
+ def _assert_indexed_slices_equal(self, left, right):
+ self.assertAllEqual(
+ self.evaluate(ops.convert_to_tensor(left)),
+ self.evaluate(ops.convert_to_tensor(right)))
+
+ def testNoGradients(self):
+ self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))
+
+ def testOneGradient(self):
+ t = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ result = gradients_impl._AggregateIndexedSlicesGradients([t])
+ self._assert_indexed_slices_equal(t, result)
+
+ def testMultipleGradients(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]]))
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ self._assert_indexed_slices_equal(total, result)
+
+ def testMultipleGradientsWithNones(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]]))
+ t3 = None
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
+ self._assert_indexed_slices_equal(total, result)
+
+ def testMixedTensorAndIndexedSlices(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]])
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ self._assert_indexed_slices_equal(total, result)
+
+
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+ def testDefaultGradYs(self):
+ with ops.Graph().as_default():
+ tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ a = constant(1.0)
+ tl = list_ops.tensor_list_push_back(tl, a)
+
+ grad_tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+ grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+ with self.cached_session() as sess:
+ self.assertEquals(self.evaluate(grad), 5.)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/test/TensorFlowNET.UnitTest/nn_test/nn_test.py b/test/TensorFlowNET.UnitTest/nn_test/nn_test.py
new file mode 100644
index 00000000..82fab741
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/nn_test/nn_test.py
@@ -0,0 +1,1243 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for miscellaneous functionality in tensorflow.ops.nn."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from absl.testing import parameterized
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+from tensorflow.python.ops.nn_impl import _compute_sampled_logits
+from tensorflow.python.platform import test as test_lib
+
+
+class ZeroFractionTest(test_lib.TestCase):
+
+ def _ZeroFraction(self, x):
+ assert x.shape
+ total_elements = np.prod(x.shape)
+ nonzeros = np.count_nonzero(x.flatten())
+ return 1.0 - nonzeros / total_elements
+
+ @test_util.run_deprecated_v1
+ def testZeroFraction(self):
+ x_shape = [5, 17]
+ x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
+ y_np = self._ZeroFraction(x_np)
+
+ x_tf = constant_op.constant(x_np)
+ x_tf.set_shape(x_shape)
+ y_tf = nn_impl.zero_fraction(x_tf)
+ y_tf_np = self.evaluate(y_tf)
+
+ eps = 1e-8
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ @test_util.run_deprecated_v1
+ def testZeroFractionEmpty(self):
+ x = np.zeros(0)
+ y = self.evaluate(nn_impl.zero_fraction(x))
+ self.assertTrue(np.isnan(y))
+
+ @test_util.run_deprecated_v1
+ def testZeroFraction2_27Zeros(self):
+ sparsity = nn_impl.zero_fraction(
+ array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8))
+ self.assertAllClose(1.0, self.evaluate(sparsity))
+
+ @test_util.run_deprecated_v1
+ def testZeroFraction2_27Ones(self):
+ sparsity = nn_impl.zero_fraction(
+ array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8))
+ self.assertAllClose(0.0, self.evaluate(sparsity))
+
+ @test_util.run_deprecated_v1
+ def testUnknownSize(self):
+ value = array_ops.placeholder(dtype=dtypes.float32)
+ sparsity = nn_impl.zero_fraction(value)
+ with self.cached_session() as sess:
+ self.assertAllClose(
+ 0.25,
+ sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]}))
+
+
+class SoftmaxTest(test_lib.TestCase, parameterized.TestCase):
+
+ def _softmax(self, x):
+ assert len(x.shape) == 2
+ m = x.max(1)[:, np.newaxis]
+ u = np.exp(x - m)
+ z = u.sum(1)[:, np.newaxis]
+ return u / z
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSoftmax(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ y_np = self._softmax(x_np)
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn_ops.softmax_v2(x_tf)
+ y_tf_last_dim = nn_ops.softmax_v2(x_tf, 1)
+ y_tf_np = self.evaluate(y_tf)
+ y_tf_last_dim_np = self.evaluate(y_tf_last_dim)
+ eps = 1e-3
+ self.assertAllClose(y_tf_np, y_np, eps)
+ self.assertAllClose(y_tf_last_dim_np, y_np, eps)
+
+ def testSoftmaxAxes(self):
+ arr = np.linspace(0., 1, 12).reshape(3, 4)
+ x_neg_axis = nn_ops.softmax_v2(arr, axis=-2)
+ y_pos_axis = nn_ops.softmax_v2(arr, axis=0)
+ z_gt_axis = nn_ops.softmax_v2(arr, axis=0)
+ x_neg_axis_tf = self.evaluate(x_neg_axis)
+ y_pos_axis_tf = self.evaluate(y_pos_axis)
+ z_gt_axis_tf = self.evaluate(z_gt_axis)
+ eps = 1e-3
+ self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps)
+ self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps)
+
+ @parameterized.parameters(((5, 10),), ((2, 3, 4),))
+ @test_util.run_deprecated_v1
+ def testGradient(self, x_shape):
+ x_np = np.random.randn(*x_shape).astype(np.float64)
+ with self.cached_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn_ops.softmax_v2(x_tf)
+ err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
+ x_shape)
+ eps = 2e-8
+ self.assertLess(err, eps)
+
+
+class LogPoissonLossTest(test_lib.TestCase):
+
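+  # The reference below is the Poisson negative log-likelihood with log-rate x
+  # and target z, i.e. exp(x) - z * x with the constant log(z!) term dropped;
+  # compute_full_loss restores log(z!) via Stirling's approximation, applied
+  # only where z > 1 (hence the masked_array fill below).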
+ def _log_poisson_loss(self, x, z, compute_full_loss=False):
+ lpl = np.exp(x) - z * x
+ if compute_full_loss:
+ stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)
+ lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)
+ return lpl
+
+ @test_util.run_in_graph_and_eager_modes
+ def testLogPoissonLoss(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ z_np = np.random.randint(0, 5, size=x_shape).astype(np.float32)
+ y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False)
+ y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True)
+ y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False)
+ y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True)
+ y_tf_np = self.evaluate(y_tf)
+ y_tf_np_stirling = self.evaluate(y_tf_stirling)
+ eps = 1e-3
+ self.assertAllClose(y_tf_np, y_np, eps)
+ self.assertAllClose(y_tf_np_stirling, y_np_stirling, eps)
+
+ @test_util.run_deprecated_v1
+ def testGradient(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float64)
+ z_np = np.random.randint(0, 5, size=x_shape).astype(np.float64)
+ with self.cached_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn_impl.log_poisson_loss(z_np, x_tf, compute_full_loss=False)
+ y_tf_stirling = nn_impl.log_poisson_loss(
+ z_np, x_tf, compute_full_loss=True)
+ err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
+ x_shape)
+ err_stirling = gradient_checker.compute_gradient_error(
+ x_tf, x_shape, y_tf_stirling, x_shape)
+ eps = 1e-6
+ self.assertLess(err, eps)
+ self.assertLess(err_stirling, eps)
+
+
+class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase):
+
+ def _log_softmax(self, x):
+ assert len(x.shape) == 2
+ m = x.max(1)[:, np.newaxis]
+ u = x - m
+ return u - np.log(np.sum(np.exp(u), 1, keepdims=True))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testLogSoftmax(self):
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ y_np = self._log_softmax(x_np)
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn_ops.log_softmax_v2(x_tf)
+ y_tf_np = self.evaluate(y_tf)
+ eps = 1e-3
+ self.assertAllClose(y_tf_np, y_np, eps)
+
+ def testLogSoftmaxAxes(self):
+ arr = np.linspace(0., 1, 12).reshape(3, 4)
+ x_neg_axis = nn_ops.log_softmax_v2(arr, axis=-2)
+ y_pos_axis = nn_ops.log_softmax_v2(arr, axis=0)
+ z_gt_axis = nn_ops.log_softmax_v2(arr, axis=0)
+ x_neg_axis_tf = self.evaluate(x_neg_axis)
+ y_pos_axis_tf = self.evaluate(y_pos_axis)
+ z_gt_axis_tf = self.evaluate(z_gt_axis)
+ eps = 1e-3
+ self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps)
+ self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps)
+
+ @parameterized.parameters(((5, 10),), ((2, 3, 4),))
+ @test_util.run_deprecated_v1
+ def testGradient(self, x_shape):
+ x_np = np.random.randn(*x_shape).astype(np.float64)
+ with self.cached_session():
+ x_tf = constant_op.constant(x_np)
+ y_tf = nn_ops.log_softmax_v2(x_tf)
+ err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
+ x_shape)
+ eps = 1e-7
+ self.assertLess(err, eps)
+
+
+class L2LossTest(test_lib.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testL2Loss(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(
+ [1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x", dtype=dtype)
+ l2loss = nn_ops.l2_loss(x)
+ value = self.evaluate(l2loss)
+ self.assertAllClose(7.0, value)
+
+ @test_util.run_deprecated_v1
+ def testGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1) # Make it reproducible.
+ x_val = np.random.random_sample(x_shape).astype(np.float64)
+ with self.cached_session():
+ x = constant_op.constant(x_val, name="x")
+ output = nn_ops.l2_loss(x)
+ err = gradient_checker.compute_gradient_error(x, x_shape, output, [1])
+ print("L2Loss gradient err = %g " % err)
+ err_tolerance = 1e-10
+ self.assertLess(err, err_tolerance)
+
+
+class L2NormalizeTest(test_lib.TestCase):
+
+ def _l2Normalize(self, x, dim):
+ if isinstance(dim, list):
+ norm = np.linalg.norm(x, axis=tuple(dim))
+ for d in dim:
+ norm = np.expand_dims(norm, d)
+ return x / norm
+ else:
+ norm = np.apply_along_axis(np.linalg.norm, dim, x)
+ return x / np.expand_dims(norm, dim)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testL2Normalize(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float32)
+ for dim in range(len(x_shape)):
+ y_np = self._l2Normalize(x_np, dim)
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn_impl.l2_normalize_v2(x_tf, dim)
+ self.assertAllClose(y_np, self.evaluate(y_tf))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testL2NormalizeDimArray(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float32)
+ dim = [1, 2]
+ y_np = self._l2Normalize(x_np, dim)
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn_impl.l2_normalize_v2(x_tf, dim)
+ self.assertAllClose(y_np, self.evaluate(y_tf))
+
+ @test_util.run_deprecated_v1
+ def testL2NormalizeGradient(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float64)
+ for dim in range(len(x_shape)):
+ with self.cached_session():
+ x_tf = constant_op.constant(x_np, name="x")
+ y_tf = nn_impl.l2_normalize_v2(x_tf, dim)
+ err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
+ x_shape)
+ print("L2Normalize gradient err = %g " % err)
+ self.assertLess(err, 1e-4)
+
+
+class DropoutTest(test_lib.TestCase):
+
+ def testDropout(self):
+    # Runs dropout on an all-ones tensor 10 times, sums the number of surviving
+    # (non-zero) entries and validates that it is approximately the expected
+    # count for the keep probability (see the sketch after this test).
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
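+  # NOTE (port reference, not part of the upstream test): a minimal NumPy-only
+  # sketch of the property the dropout tests in this class rely on; with an
+  # all-ones input the surviving entries are scaled to 1/keep_prob, so the only
+  # values are 0 and 1/keep_prob and the non-zero count is ~ size * keep_prob.
+  def _dropout_count_sketch(self, keep_prob=0.5, size=10000):
+    mask = np.random.uniform(size=size) < keep_prob
+    dropped = np.where(mask, 1.0 / keep_prob, 0.0)
+    return np.count_nonzero(dropped), size * keep_prob  # observed vs. expected
+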
+ def testShapedDropout(self):
+    # Runs dropout on an all-ones tensor 10 times, sums the number of surviving
+    # (non-zero) entries and validates that it is approximately the expected
+    # count for the keep probability. This time with a shaped noise_shape.
+ x_dim = 40 * 30
+ y_dim = 3
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
+ def testShapedDropoutCorrelation(self):
+ # Runs a shaped dropout and tests that the correlations are correct.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+        # Verifies that each row has only one type of activation, since the
+        # noise is shared along the y dimension.
+ for i in xrange(x_dim):
+ sorted_value = np.unique(np.sort(value[i, :]))
+ self.assertEqual(sorted_value.size, 1)
+
+ @test_util.run_deprecated_v1
+ def testDropoutPlaceholderKeepProb(self):
+    # Runs dropout on an all-ones tensor 10 times, sums the number of surviving
+    # (non-zero) entries and validates that it is approximately the expected
+    # count for the keep probability, with keep_prob fed through a placeholder.
+ x_dim = 40
+ y_dim = 30
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.cached_session():
+ t = constant_op.constant(
+ 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ keep_prob_placeholder = array_ops.placeholder(dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob_placeholder)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
+ @test_util.run_deprecated_v1
+ def testShapedDropoutUnknownShape(self):
+ x_dim = 40
+ y_dim = 30
+ keep_prob = 0.5
+ x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout_x = nn_ops.dropout(
+ x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32))
+ self.assertEqual(x.get_shape(), dropout_x.get_shape())
+
+ def testPartialShapedDropout(self):
+ x_dim = 40 * 30
+ y_dim = 3
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ # Set noise_shape=[None, 1] which means [x_dim, 1].
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
+ @test_util.run_deprecated_v1
+ def testInvalidKeepProb(self):
+ x_dim = 40
+ y_dim = 30
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout(t, -1.0)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout(t, 1.1)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout(t, [0.0, 1.0])
+ with self.assertRaises(ValueError):
+ nn_ops.dropout(t, array_ops.placeholder(dtypes.float64))
+ with self.assertRaises(ValueError):
+ nn_ops.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2]))
+
+ @test_util.run_deprecated_v1
+ def testInvalidRate(self):
+ x_dim = 40
+ y_dim = 30
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout_v2(t, -1.0)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout_v2(t, 1.1)
+ with self.assertRaises(ValueError):
+ nn_ops.dropout_v2(t, [0.0, 1.0])
+
+ @test_util.run_deprecated_v1
+ def testShapedDropoutShapeError(self):
+ # Runs shaped dropout and verifies an error is thrown on misshapen noise.
+ x_dim = 40
+ y_dim = 30
+ keep_prob = 0.5
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+ with self.assertRaises(ValueError):
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+ with self.assertRaises(ValueError):
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+ with self.assertRaises(ValueError):
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim])
+ # test that broadcasting proceeds
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[y_dim])
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, y_dim])
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ _ = nn_ops.dropout(t, keep_prob, noise_shape=[1, 1])
+
+ def testNoDropoutFast(self):
+ x = array_ops.zeros((5,))
+ y = nn_ops.dropout(x, keep_prob=1)
+ self.assertTrue(x is y)
+
+ y = nn_ops.dropout_v2(x, rate=0)
+ self.assertTrue(x is y)
+
+ def testDropoutWithIntegerInputs(self):
+ x = constant_op.constant([1, 1, 1, 1, 1])
+ with self.assertRaises(ValueError):
+ _ = nn_ops.dropout(x, 0.5)
+
+
+class ComputeSampledLogitsTest(test_lib.TestCase):
+
+ def setUp(self):
+ self._eps = 1e-3
+
+ def _GenerateTestData(self, num_classes, dim, batch_size, num_true, labels,
+ sampled, subtract_log_q):
+ """Randomly generates input/output data for a single test case.
+
+ This function returns numpy constants for use in a test case.
+
+ Args:
+ num_classes: An int. The number of embedding classes in the test case.
+ dim: An int. The dimension of the embedding.
+ batch_size: An int. The batch size.
+ num_true: An int. The number of target classes per training example.
+ labels: A list of batch_size * num_true ints. The target classes.
+ sampled: A list of indices in [0, num_classes).
+ subtract_log_q: A bool corresponding to the parameter in
+ _compute_sampled_logits().
+
+ Returns:
+ weights: Embedding weights to use as test input. It is a numpy array
+ of shape [num_classes, dim]
+ biases: Embedding biases to use as test input. It is a numpy array
+ of shape [num_classes].
+ hidden_acts: Forward activations of the network to use as test input.
+ It is a numpy array of shape [batch_size, dim].
+ sampled_vals: A tuple based on `sampled` to use as test input in the
+ format returned by a *_candidate_sampler function.
+ exp_logits: The output logits expected from _compute_sampled_logits().
+ It is a numpy array of shape [batch_size, num_true + len(sampled)].
+ exp_labels: The output labels expected from _compute_sampled_logits().
+ It is a numpy array of shape [batch_size, num_true + len(sampled)].
+ """
+ weights = np.random.randn(num_classes, dim).astype(np.float32)
+ biases = np.random.randn(num_classes).astype(np.float32)
+ hidden_acts = np.random.randn(batch_size, dim).astype(np.float32)
+
+ true_exp = np.full([batch_size, 1], fill_value=0.5, dtype=np.float32)
+ sampled_exp = np.full([len(sampled)], fill_value=0.5, dtype=np.float32)
+ sampled_vals = (sampled, true_exp, sampled_exp)
+
+ sampled_w, sampled_b = weights[sampled], biases[sampled]
+ true_w, true_b = weights[labels], biases[labels]
+
+ true_logits = np.sum(
+ hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
+ (batch_size, num_true, dim)),
+ axis=2)
+ true_b = true_b.reshape((batch_size, num_true))
+ true_logits += true_b
+ sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b
+
+ if subtract_log_q:
+ true_logits -= np.log(true_exp)
+ sampled_logits -= np.log(sampled_exp[np.newaxis, :])
+
+ exp_logits = np.concatenate([true_logits, sampled_logits], axis=1)
+ exp_labels = np.hstack((np.ones_like(true_logits) / num_true,
+ np.zeros_like(sampled_logits)))
+
+ return weights, biases, hidden_acts, sampled_vals, exp_logits, exp_labels
+
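+  # NOTE (port reference, not part of the upstream test): a tiny shape check for
+  # the data produced by _GenerateTestData above; the helper name and concrete
+  # sizes below are illustrative only.
+  def _sampled_logits_shape_sketch(self):
+    batch_size, num_true, dim, sampled = 3, 2, 10, [1, 0, 2, 3]
+    outputs = self._GenerateTestData(
+        num_classes=5, dim=dim, batch_size=batch_size, num_true=num_true,
+        labels=np.random.randint(0, 5, size=batch_size * num_true),
+        sampled=sampled, subtract_log_q=False)
+    exp_logits, exp_labels = outputs[4], outputs[5]
+    # One column per true class plus one per sampled class.
+    assert exp_logits.shape == (batch_size, num_true + len(sampled))
+    assert exp_labels.shape == (batch_size, num_true + len(sampled))
+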
+ def _ShardTestEmbeddings(self, weights, biases, num_shards):
+ """Shards the weights and biases returned by _GenerateTestData.
+
+ Args:
+ weights: The weights returned by _GenerateTestData.
+ biases: The biases returned by _GenerateTestData.
+ num_shards: The number of shards to create.
+
+ Returns:
+ sharded_weights: A list of size `num_shards` containing all the weights.
+ sharded_biases: A list of size `num_shards` containing all the biases.
+ """
+ with ops.Graph().as_default() as g:
+ sharded_weights = variable_scope.get_variable(
+ "w",
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+ initializer=constant_op.constant(weights))
+ sharded_biases = variable_scope.get_variable(
+ "b",
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+ initializer=constant_op.constant(biases))
+ with self.session(graph=g) as sess:
+ variables.global_variables_initializer().run()
+ return self.evaluate([list(sharded_weights), list(sharded_biases)])
+
+ def testShapes(self):
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_basic_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+      self.assertEqual(exp_logits.shape, got_logits.shape)
+      self.assertEqual(exp_labels.shape, got_labels.shape)
+
+ def testBasic(self):
+ """Without accidental hit removal or subtract_log_q."""
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_basic_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ def testAccidentalHitRemoval(self):
+ """With accidental hit removal, no subtract_log_q."""
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ sampled = [1, 0, 2, 3]
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, _,
+ _) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=sampled,
+ subtract_log_q=False)
+ logits_tensor, _ = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=len(sampled),
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=True,
+ partition_strategy="div",
+ name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true)
+ # Test that the exponentiated logits of accidental hits are near 0.
+ # First we need to find the hits in this random test run:
+ labels_reshape = labels.reshape((batch_size, num_true))
+ got_logits = self.evaluate(logits_tensor)
+ for row in xrange(batch_size):
+ row_labels = labels_reshape[row, :]
+ for col in xrange(len(sampled)):
+ if sampled[col] in row_labels:
+            # The sampled columns start after the first num_true columns of
+            # the logits.
+ self.assertNear(
+ np.exp(got_logits[row, col + num_true]), 0., self._eps)
+
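+  # NOTE (port reference, not part of the upstream test): the check above relies
+  # on accidental hits being pushed to a very large negative logit so that the
+  # exponentiated logit is ~0; a minimal NumPy illustration with made-up values:
+  def _accidental_hit_mask_sketch(self):
+    sampled_logits = np.array([[0.3, -1.2, 2.0]], dtype=np.float32)
+    is_accidental_hit = np.array([[False, True, False]])
+    masked = np.where(is_accidental_hit, np.float32(-1e9), sampled_logits)
+    return np.exp(masked)  # the masked column underflows to ~0 after exp()
+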
+ def testSubtractLogQ(self):
+ """With subtract_log_q, no accidental hit removal."""
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=True)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_subtract_log_q_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ def testSharded(self):
+ """With sharded weights and sharded biases."""
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_sharded_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ def testNCELoss(self):
+ # A simple test to verify the numerics.
+
+ def _SigmoidCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ pred = 1. / (1. + np.exp(-logits))
+ eps = 0.0001
+ pred = np.minimum(np.maximum(pred, eps), 1 - eps)
+ return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)
+
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ labels = [0, 1, 2]
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=1,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=True)
+ exp_nce_loss = np.sum(
+ _SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1)
+
+ got_nce_loss = nn_impl.nce_loss_v2(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals)
+
+ self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
+
+ # Test with sharded weights and sharded biases.
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ got_nce_loss = nn_impl.nce_loss_v2(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals)
+
+ self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
+
+ def testSampledSoftmaxLoss(self):
+ # A simple test to verify the numerics.
+
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ labels = [0, 1, 2]
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=1,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=True)
+ exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
+ exp_logits, exp_labels)
+
+ got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ remove_accidental_hits=False)
+
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-4)
+
+ # Test with sharded weights and sharded biases.
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ got_sampled_softmax_loss = nn_impl.sampled_softmax_loss_v2(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ remove_accidental_hits=False)
+
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-4)
+
+ def testSampledSoftmaxLossBf16(self):
+ # A simple test to verify the numerics for bfloat16.
+ def _SoftmaxCrossEntropyWithLogits(logits, targets):
+ # logits, targets: float arrays of the same shape.
+ assert logits.shape == targets.shape
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
+ pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+ return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+ np.random.seed(0)
+ num_classes = 5
+ batch_size = 3
+ labels = [0, 1, 2]
+ sampled = [1, 0, 2, 3]
+ (weights, biases, hidden_acts, _, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=1,
+ labels=labels,
+ sampled=sampled,
+ subtract_log_q=True)
+ exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
+ exp_logits, exp_labels)
+
+ true_exp_bf16 = np.full([batch_size, 1],
+ fill_value=0.5,
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_exp_bf16 = np.full([len(sampled)],
+ fill_value=0.5,
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
+
+ got_sampled_softmax_loss = math_ops.cast(
+ nn_impl.sampled_softmax_loss_v2(
+ weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
+ biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
+ labels=constant_op.constant(
+ labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
+ inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals_bf16,
+ remove_accidental_hits=False), dtypes.float32)
+
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-1)
+
+
+class CReluTest(test_lib.TestCase):
+
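+  # CReLU concatenates relu(x) and relu(-x) along the last axis, which is what
+  # the NumPy expression below (x * (x > 0) and -x * (x < 0)) reproduces.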
+ def test(self):
+ np.random.seed(1) # Make it reproducible.
+ x = np.random.randn(3, 4).astype(np.float32)
+ y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1)
+
+ z = self.evaluate(nn_ops.crelu(constant_op.constant(x)))
+ self.assertAllClose(y, z, 1e-4)
+
+
+class ReluTest(test_lib.TestCase):
+
+ def test(self):
+ np.random.seed(1) # Make it reproducible.
+ x = np.random.randn(3, 4).astype(np.float32)
+ y = np.maximum(x, 0.0)
+
+ z = self.evaluate(nn_ops.relu(constant_op.constant(x)))
+ self.assertAllEqual(y, z)
+
+ @test_util.run_deprecated_v1
+ def testNaNs(self):
+ # Test that relu(nan) = nan for various sizes.
+ for i in range(18):
+ x = np.zeros(i) + np.nan
+ with self.cached_session():
+ z = nn_ops.relu(constant_op.constant(x)).eval()
+ self.assertTrue(np.isnan(z).all())
+
+
+class LeakyReluTest(test_lib.TestCase):
+
+ def testRange(self):
+ batch_size = 3
+ height, width = 4, 4
+ np.random.seed(1) # Make it reproducible.
+ inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype(
+ np.float32)
+ inputs = constant_op.constant(inputs)
+
+ outputs = nn_ops.leaky_relu(inputs)
+ self.assertEquals(inputs.shape, outputs.shape)
+
+ inputs, outputs = self.evaluate([inputs, outputs])
+
+ self.assertGreaterEqual(outputs.min(), 0.0)
+ self.assertLessEqual(outputs.max(), 1.0)
+ self.assertAllClose(inputs, outputs)
+
+ @test_util.run_deprecated_v1
+ def testValues(self):
+ for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype)
+ outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
+
+ outputs = self.evaluate(outputs)
+
+ tol = 2e-3 if dtype == np.float16 else 1e-6
+ self.assertAllClose(
+ outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
+
+ @test_util.run_deprecated_v1
+ def testName(self):
+ np_values = np.array([-2, -1, 0, 1, 2], dtype=np.float64)
+ outputs_with_name_set = nn_ops.leaky_relu(
+ constant_op.constant(np_values),
+ name='test_relu_op')
+ self.assertEqual(outputs_with_name_set.name, 'test_relu_op:0')
+ outputs_without_name_set = nn_ops.leaky_relu(
+ constant_op.constant(np_values))
+ self.assertEqual(outputs_without_name_set.name, 'LeakyRelu:0')
+
+
+class SwishTest(test_lib.TestCase):
+
+ @test_util.run_deprecated_v1
+ def testValues(self):
+ np_values = np.array(
+ [np.linspace(-10.0, 0.0, 100),
+ np.linspace(0.0, 10.0, 100)],
+ dtype=np.float32)
+ tf_values = constant_op.constant(np_values)
+ actual_tf_outputs = nn_impl.swish(tf_values)
+ expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values)
+
+ actual_outputs, expected_outputs = self.evaluate(
+ [actual_tf_outputs, expected_tf_outputs])
+
+ self.assertAllClose(actual_outputs, expected_outputs)
+
+ @test_util.run_deprecated_v1
+ def testGradients(self):
+ shape = [5, 3, 4]
+ sigma = 5
+ input_values = np.random.randn(*shape) * sigma
+ x_tf = constant_op.constant(input_values)
+ y_tf = nn_impl.swish(x_tf)
+ with self.cached_session():
+ err = gradient_checker.compute_gradient_error(x_tf, shape, y_tf, shape)
+ self.assertLess(err, 1e-4)
+
+
+class MomentsTest(test_lib.TestCase):
+
+ def doOutputTest(self,
+ input_shape,
+ moments_axes,
+ tol=1e-4,
+ check_gradients=False):
+ for mu in [0.0, 1.0, 1e3]:
+ for sigma in [1.0, 0.1]:
+ for keep_dims in [True, False]:
+ input_values = np.random.rand(*input_shape) * sigma + mu
+ expected_mean = np.mean(
+ input_values, axis=moments_axes, keepdims=keep_dims)
+ expected_var = np.var(
+ input_values, axis=moments_axes, keepdims=keep_dims)
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ inputs = constant_op.constant(
+ input_values, shape=input_shape, dtype=dtypes.float32)
+ mean, variance = nn_impl.moments_v2(
+ inputs, moments_axes, keepdims=keep_dims)
+
+ if check_gradients:
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, mean, mean.shape.as_list())
+ self.assertLess(err, 1e-3)
+ err = gradient_checker.compute_gradient_error(
+ inputs, input_shape, variance, variance.shape.as_list())
+ self.assertLess(err, 1e-3)
+
+ # Evaluate.
+ [mean, variance] = self.evaluate([mean, variance])
+ # Make sure that there are no NaNs
+ self.assertFalse(np.isnan(mean).any())
+ self.assertFalse(np.isnan(variance).any())
+ self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
+ self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
+
+ def testOutputAndGradient2DInput0(self):
+ self.doOutputTest((10, 10), (0,), check_gradients=True)
+
+ def testOutputAndGradient2DInput01(self):
+ self.doOutputTest((10, 10), (0, 1), check_gradients=True)
+
+ def testOutput2DInput0(self):
+ self.doOutputTest((10, 300), (0,))
+
+ def testOutput2DInput1(self):
+ self.doOutputTest((10, 300), (1,))
+
+ def testOutput2DInput01(self):
+ self.doOutputTest((10, 300), (0, 1))
+
+ def testOutput4DInput0(self):
+ self.doOutputTest((10, 10, 10, 30), (0,))
+
+ def testOutput4DInput1(self):
+ self.doOutputTest((10, 10, 10, 30), (1,))
+
+ def testOutput4DInput3(self):
+ self.doOutputTest((10, 10, 10, 30), (3,))
+
+ def testOutput4DInput012(self):
+ self.doOutputTest((10, 10, 10, 30), (0, 1, 2))
+
+ def testOutput4DInput123(self):
+ self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
+
+
+class DataFormatDimMapTest(test_lib.TestCase):
+
+ def _test(self, x_val, y_val_expected):
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x)
+
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
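+  # The default mapping is NHWC -> NCHW: axis 0 (N) stays 0, 1 (H) -> 2,
+  # 2 (W) -> 3, 3 (C) -> 1; negative indices are wrapped modulo the rank (4)
+  # before being mapped, which is what the cases below enumerate.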
+ def test(self):
+ self._test(0, 0)
+ self._test(1, 2)
+ self._test(2, 3)
+ self._test(3, 1)
+ self._test(-1, 1)
+ self._test(-2, 3)
+ self._test(-3, 2)
+ self._test(-4, 0)
+ self._test([1, 3], [2, 1])
+ self._test([1, 3, -2], [2, 1, 3])
+ self._test([1, -3, -2], [2, 2, 3])
+ self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
+
+ def testNHWCtoNCHW(self):
+ x_val = [1, -3, -2]
+ y_val_expected = [2, 2, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testNHWCtoHWNC(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testNHWCtoWHCN(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+ def testArbitraryASCII(self):
+ x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
+ y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
+
+
+class DataFormatVectorPermuteTest(test_lib.TestCase):
+
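+  # data_format_vec_permute reorders a length-4 vector (or a [4, 2] matrix) of
+  # per-dimension values from src_format to dst_format order, e.g. an NHWC
+  # shape [7, 4, 9, 3] becomes [7, 3, 4, 9] in NCHW order (the default below).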
+ def testNHWCToNCHW(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x)
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [7, 3, 4, 9])
+
+ def testNCHWToNHWC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [7, 9, 3, 4])
+
+ def testNHWCToHWNC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x_val = [7, 4, 9, 3]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [9, 7, 4, 3])
+
+ def testNHWCToNCHW2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x)
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
+
+ def testNHWCToHWNC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
+
+ def testNCHWToNHWC2D(self):
+ x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
+ with test_util.use_gpu():
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
+
+
+if __name__ == "__main__":
+ test_lib.main()