From df913078b9260df2b46cb81bf34a1108b6af0bfe Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Thu, 27 Apr 2023 06:40:26 -0500 Subject: [PATCH] Fix namespace compile issue. --- .../Keras/Engine/IOptimizer.cs | 8 +++++ .../Optimizers/OptimizerV2.cs | 36 +++++++++++++++++++ .../ComplexTest.cs | 3 +- .../SignalTest.cs | 3 +- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs index 5458a536..1f989391 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs @@ -10,5 +10,13 @@ public interface IOptimizer void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, string name = null, bool experimental_aggregate_gradients = true); + + void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true); + IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null); } diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs index 44c163bc..1e4dbe08 100644 --- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs @@ -78,6 +78,42 @@ namespace Tensorflow.Keras.Optimizers }); } + public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + => apply_gradients(new[] { grads_and_vars }, + name: name, + experimental_aggregate_gradients: experimental_aggregate_gradients); + + /// + /// Apply gradients to variables. + /// + /// + /// + /// + public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + { + var var_list = grads_and_vars.Select(x => x.Item2).ToArray(); + tf_with(ops.name_scope(_name), delegate + { + ops.init_scope(); + _create_all_weights(var_list); + if (grads_and_vars == null || grads_and_vars.Count() == 0) + return control_flow_ops.no_op(); + + var apply_state = _prepare(var_list); + // if(experimental_aggregate_gradients) + { + // var reduced_grads = _aggregate_gradients(grads_and_vars); + _distributed_apply(grads_and_vars.Select(x => (x.Item1, (IVariableV1)x.Item2)), name, apply_state); + } + + return null; + }); + } + void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary> apply_state) { _resource_apply_dense(var, grad, apply_state); diff --git a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs index a57ec929..abb44eee 100644 --- a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs @@ -5,8 +5,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow; using static Tensorflow.Binding; -using Buffer = Tensorflow.Buffer; -using TensorFlowNET.Keras.UnitTest; +using Tensorflow.Keras.UnitTest; namespace TensorFlowNET.UnitTest.Basics { diff --git a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs index 01014a10..cc09b101 100644 --- a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs @@ -5,8 +5,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow; using static Tensorflow.Binding; -using Buffer = Tensorflow.Buffer; -using TensorFlowNET.Keras.UnitTest; +using Tensorflow.Keras.UnitTest; namespace TensorFlowNET.UnitTest.Basics {