From 1623bb87a7ee8da5eb1978ed8b89ae8604b10d57 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 28 Nov 2020 18:00:52 -0600 Subject: [PATCH] fix _assign_new_value for BatchNormaliztion. --- src/TensorFlowNET.Core/Gradients/TapeTensor.cs | 3 +++ src/TensorFlowNET.Keras/Layers/BatchNormalization.cs | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs index 92c4e39f..be030321 100644 --- a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs +++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs @@ -22,5 +22,8 @@ namespace Tensorflow.Gradients public Tensor OnesLike() => tf.ones(shape: shape, dtype: dtype); + + public override string ToString() + => $"{id}, {shape}, {dtype.as_numpy_name()}"; } } diff --git a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs index a160e496..18bd5c55 100644 --- a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs @@ -196,7 +196,7 @@ namespace Tensorflow.Keras.Layers _assign_moving_average(moving_variance, variance, momentum_tensor); if (use_fused_avg_updates) - _assign_new_value(moving_variance, mean); + _assign_new_value(moving_variance, variance); else _assign_moving_average(moving_variance, variance, momentum_tensor);