From b968fd79ab156bfca62f434c7fb936e2ed512455 Mon Sep 17 00:00:00 2001 From: dogvane Date: Mon, 10 Jul 2023 00:41:23 +0800 Subject: [PATCH] add avg_pool_grad function --- src/TensorFlowNET.Core/Gradients/nn_grad.cs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index a1ac97a9..3a6efd54 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -365,6 +365,23 @@ namespace Tensorflow.Gradients }; } + [RegisterGradient("AvgPool")] + public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads) + { + Tensor grad = grads[0]; + + return new Tensor[] + { + gen_nn_ops.avg_pool_grad( + array_ops.shape(op.inputs[0]), + grad, + op.get_attr_list("ksize"), + op.get_attr_list("strides"), + op.get_attr("padding").ToString(), + op.get_attr("data_format").ToString()) + }; + } + /// /// Return the gradients for TopK. ///