From 8b5ef9c9a16cebeeb3ae66ad82478d77c12b6ca4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 19 Feb 2020 05:55:11 -0600 Subject: [PATCH] AddV2 --- src/TensorFlowNET.Core/Gradients/nn_grad.cs | 16 +++++++++++++--- .../Operations/OpDefLibrary.cs | 4 ++++ .../Operations/Operation.Implicit.cs | 2 +- .../Operations/Operation.Input.cs | 3 +++ .../Operations/Operation.Output.cs | 3 +++ src/TensorFlowNET.Core/Operations/Operation.cs | 7 +++++++ src/TensorFlowNET.Core/Operations/array_ops.cs | 2 +- .../Operations/gen_math_ops.cs | 3 ++- src/TensorFlowNET.Core/Operations/math_ops.cs | 6 ++++++ src/TensorFlowNET.Core/Operations/nn_impl.py.cs | 2 +- src/TensorFlowNET.Core/TensorFlow.Binding.csproj | 3 ++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 3 +++ src/TensorFlowNET.Core/ops.name_scope.cs | 4 ++-- .../Hub/MnistModelLoaderTest.cs | 2 +- 14 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index 967b3c21..e4502ad8 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -170,6 +170,14 @@ namespace Tensorflow.Gradients public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) => _BaseFusedBatchNormGrad(op, 0, grads); + [RegisterGradient("FusedBatchNormV2")] + public static Tensor[] _FusedBatchNormV2Grad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 1, grads); + + [RegisterGradient("FusedBatchNormV3")] + public static Tensor[] _FusedBatchNormV3Grad(Operation op, Tensor[] grads) + => _BaseFusedBatchNormGrad(op, 2, grads); + /// /// Return the gradients for the 3 inputs of BatchNorm. /// @@ -190,8 +198,10 @@ namespace Tensorflow.Gradients switch (version) { case 2: - throw new NotImplementedException(""); + grad_fun = gen_nn_ops.fused_batch_norm_grad_v3; + break; case 1: + // grad_fun = gen_nn_ops.fused_batch_norm_grad_v2; throw new NotImplementedException(""); default: grad_fun = gen_nn_ops.fused_batch_norm_grad; @@ -225,8 +235,8 @@ namespace Tensorflow.Gradients YBackprop = grad_y, X = x, Scale = scale, - ReserveSpace1 = op.outputs[3], - ReserveSpace2 = op.outputs[4], + ReserveSpace1 = pop_mean, + ReserveSpace2 = pop_var, ReserveSpace3 = version == 2 ? op.outputs[5] : null, Epsilon = epsilon, DataFormat = data_format, diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 5700ccdd..e842fcb4 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -65,6 +65,10 @@ namespace Tensorflow var base_types = new List(); var types = new List(); +#if DEBUG + if (op_type_name == "FusedBatchNormGradV3") + ; +#endif // Perform input type inference foreach (var input_arg in op_def.InputArg) { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs index 8de412c8..9cadac0c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs @@ -32,7 +32,7 @@ namespace Tensorflow public override string ToString() { - return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; + return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $""; } public override bool Equals(object obj) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index fdf92504..bbd11b0e 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -80,6 +80,9 @@ namespace Tensorflow /// reasons, or to ensure that the side effects of an op are observed /// in the correct order. /// +#if SERIALIZABLE + [JsonIgnore] +#endif public Operation[] control_inputs { get diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index abe8e9c1..3a70c107 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -45,6 +45,9 @@ namespace Tensorflow } private Tensor[] _outputs; +#if SERIALIZABLE + [JsonIgnore] +#endif public Tensor[] outputs => _outputs; #if SERIALIZABLE [JsonIgnore] diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 831e6ca5..a0acf4bb 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -74,6 +74,9 @@ namespace Tensorflow public TF_DataType dtype => TF_DataType.DtInvalid; public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); +#if SERIALIZABLE + [JsonIgnore] +#endif public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); #if SERIALIZABLE [JsonIgnore] @@ -152,6 +155,10 @@ namespace Tensorflow { _graph = g; +#if DEBUG + if (node_def.Name == "define_second_stage_train/gradients/define_loss/conv_lobj_branch/batch_normalization/cond/FusedBatchNormV3_1_grad/FusedBatchNormGradV3") + ; +#endif // Build the list of control inputs. var control_input_ops = new List(); if (control_inputs != null) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index f9f2f58f..5374d72b 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -427,7 +427,7 @@ namespace Tensorflow if (!tf.context.executing_eagerly()) { var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); + var input_shape = input_tensor.TensorShape; if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) { var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5cf240e8..7621d1b2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -141,7 +141,8 @@ namespace Tensorflow public static Tensor add(Tx x, Ty y, string name = null) { - var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); + // forward_compatible(2019, 6, 25): + var _op = _op_def_lib._apply_op_helper("AddV2", name, args: new { x, y }); return _op.output; } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index bb8d7134..bc904028 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -355,6 +355,9 @@ namespace Tensorflow return _may_reduce_to_scalar(keepdims, axis, all); } + public static Tensor realdiv(Tensor x, Tensor y, string name = null) + => gen_math_ops.real_div(x, y, name: name); + /// /// Computes log(sum(exp(elements across dimensions of a tensor))). /// Reduces `input_tensor` along the dimensions given in `axis`. @@ -561,6 +564,9 @@ namespace Tensorflow public static Tensor rsqrt(Tensor x, string name = null) => gen_math_ops.rsqrt(x, name: name); + public static Tensor pow(Tx x, Ty y, string name = null) + => gen_math_ops.pow(x, y, name: name); + public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") { if(limit == null) diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index 42103b00..a6c9e221 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -117,7 +117,7 @@ namespace Tensorflow var min_epsilon = 1.001e-5f; epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; - var results = gen_nn_ops.fused_batch_norm(x, + var results = gen_nn_ops.fused_batch_norm_v3(x, scale_tensor, offset_tensor, mean, diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 09e2674a..369c6c81 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -33,7 +33,7 @@ https://tensorflownet.readthedocs.io true - TRACE;DEBUG;SERIALIZABLE_ + TRACE;DEBUG;SERIALIZABLE @@ -62,6 +62,7 @@ https://tensorflownet.readthedocs.io + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index efac802d..2ec02232 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -115,6 +115,9 @@ namespace Tensorflow /// /// The name of the device on which this tensor will be produced, or null. /// +#if SERIALIZABLE + [JsonIgnore] +#endif public string Device => op.Device; #if SERIALIZABLE [JsonIgnore] diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 80397667..55e9cf61 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -68,7 +68,7 @@ namespace Tensorflow var g = get_default_graph(); g._name_stack = old_stack; } - + public void __exit__() { } @@ -82,7 +82,7 @@ namespace Tensorflow { } - + /// /// __enter__() /// diff --git a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs index 26dfd3b6..b1c90b32 100644 --- a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs +++ b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs @@ -2,7 +2,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System.Threading.Tasks; using Tensorflow.Hub; -namespace UnitTest +namespace TensorFlowNET.UnitTest { [TestClass] public class MnistModelLoaderTest