diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index f1a33371..ab228c47 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -388,7 +388,7 @@ namespace Tensorflow var handle = return_oper_handle.node + Marshal.SizeOf() * i; return_opers[i] = new Operation(*(IntPtr*)handle); } - + return return_opers; } @@ -439,6 +439,18 @@ namespace Tensorflow c_api.TF_DeleteGraph(_handle); } + /// + /// Returns the with the given . + /// This method may be called concurrently from multiple threads. + /// + /// The name of the `Tensor` to return. + /// If does not correspond to a tensor in this graph. + /// The `Tensor` with the given . + public Tensor get_tensor_by_name(string name) + { + return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); + } + public void __enter__() { } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 9e1e72f2..7babae11 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -35,6 +35,8 @@ namespace Tensorflow public static string TRAIN_OP = "train_op"; + public static string GLOBAL_STEP = GLOBAL_STEP = "global_step"; + public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.