diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
index 5458a536..1f989391 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
@@ -10,5 +10,13 @@ public interface IOptimizer
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
+
+ void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+ string name = null,
+ bool experimental_aggregate_gradients = true);
+ void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+ string name = null,
+ bool experimental_aggregate_gradients = true);
+
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
}
diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
index 44c163bc..1e4dbe08 100644
--- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
@@ -78,6 +78,42 @@ namespace Tensorflow.Keras.Optimizers
});
}
+ public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+ string name = null,
+ bool experimental_aggregate_gradients = true)
+ => apply_gradients(new[] { grads_and_vars },
+ name: name,
+ experimental_aggregate_gradients: experimental_aggregate_gradients);
+
+ ///
+ /// Apply gradients to variables.
+ ///
+ ///
+ ///
+ ///
+ public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+ string name = null,
+ bool experimental_aggregate_gradients = true)
+ {
+ var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
+ tf_with(ops.name_scope(_name), delegate
+ {
+ ops.init_scope();
+ _create_all_weights(var_list);
+ if (grads_and_vars == null || grads_and_vars.Count() == 0)
+ return control_flow_ops.no_op();
+
+ var apply_state = _prepare(var_list);
+ // if(experimental_aggregate_gradients)
+ {
+ // var reduced_grads = _aggregate_gradients(grads_and_vars);
+ _distributed_apply(grads_and_vars.Select(x => (x.Item1, (IVariableV1)x.Item2)), name, apply_state);
+ }
+
+ return null;
+ });
+ }
+
void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary> apply_state)
{
_resource_apply_dense(var, grad, apply_state);
diff --git a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
index a57ec929..abb44eee 100644
--- a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
@@ -5,8 +5,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
-using Buffer = Tensorflow.Buffer;
-using TensorFlowNET.Keras.UnitTest;
+using Tensorflow.Keras.UnitTest;
namespace TensorFlowNET.UnitTest.Basics
{
diff --git a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs
index 01014a10..cc09b101 100644
--- a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs
@@ -5,8 +5,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
-using Buffer = Tensorflow.Buffer;
-using TensorFlowNET.Keras.UnitTest;
+using Tensorflow.Keras.UnitTest;
namespace TensorFlowNET.UnitTest.Basics
{