From ccc52dd98d0960b1f1fe0c0e3d6395f8faa84620 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Nov 2020 06:49:15 -0600 Subject: [PATCH] fixed #632 --- .../Keras/Engine/Model.Train.cs | 3 +-- .../Keras/Optimizers/OptimizerV2.cs | 5 ++++- src/TensorFlowNET.Core/Variables/state_ops.cs | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs index db00fabe..7bc9e9c6 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs @@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Engine var data = iterator.next(); var outputs = train_step(data[0], data[1]); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); - throw new NotImplementedException(""); + return null; } /// @@ -38,7 +38,6 @@ namespace Tensorflow.Keras.Engine // The _minimize call does a few extra steps unnecessary in most cases, // such as loss scaling and gradient clipping. _minimize(tape, optimizer, loss, trainable_variables); - compiled_metrics.update_state(y, y_pred); return new[] { ("loss", loss) }; } diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index 81b9f59a..fd29ea92 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Optimizers return control_flow_ops.no_op(); var apply_state = _prepare(var_list); - if(experimental_aggregate_gradients) + // if(experimental_aggregate_gradients) { // var reduced_grads = _aggregate_gradients(grads_and_vars); _distributed_apply(grads_and_vars, name, apply_state); @@ -84,6 +84,9 @@ namespace Tensorflow.Keras.Optimizers void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary> apply_state) { _resource_apply_dense(var, grad, apply_state); + // if var.constraint is not None: + // with ops.control_dependencies([update_op]): + // return var.assign(var.constraint(var)) } protected virtual Operation _resource_apply_dense(IVariableV1 var, diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 3152686b..014a010b 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -67,6 +67,21 @@ namespace Tensorflow name: name); } + public static Tensor assign(IVariableV1 @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + if (@ref.dtype.is_ref_dtype()) + return gen_state_ops.assign(@ref, + value, + validate_shape: validate_shape, + use_locking: use_locking, + name: name); + else + return @ref.assign(value, name: name); + } + public static Tensor assign_sub(IVariableV1 @ref, Tensor value, bool use_locking = false,