From a7f95991a65b6ae91332671ded7b10e9a968a3b7 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 16 Jan 2021 11:48:54 -0600 Subject: [PATCH] Fix RunInAutoMode. --- .../Contexts/Context.AutoMode.cs | 20 ++++++++++++------- .../Variables/BaseResourceVariable.cs | 2 +- .../Variables/ResourceVariable.Implicit.cs | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index a42b79f0..7db178b3 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -32,19 +32,25 @@ namespace Tensorflow.Contexts { if (tf.Context.has_graph_arg(args)) { - return graphAction(); + if (executing_eagerly()) + { + graph_mode(); + var result = graphAction(); + restore_mode(); + return result; + } + else + { + return graphAction(); + } } else { - try + if (tf.Context.executing_eagerly()) { return eagerAction(); } - catch (InvalidArgumentError ex) - { - throw ex; - } - catch (Exception ex) + else { return graphAction(); } diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index a504c61b..4a30e060 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -162,7 +162,7 @@ namespace Tensorflow /// read the value only after some condition is true. /// /// - Tensor read_value() + protected Tensor read_value() => tf_with(ops.name_scope("Read"), delegate { var value = _read_variable_op(); diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs index aa2815c7..29771c06 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs @@ -36,7 +36,7 @@ namespace Tensorflow if (as_ref) return handle; else - return AsTensor(); + return GraphElement ?? read_value(); } } }