| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class DefaultGraphStack | |||
| { | |||
| List<StackModel> stack = new List<StackModel>(); | |||
| public void set_controller(Graph @default) | |||
| { | |||
| if (!stack.Exists(x => x.Graph == @default)) | |||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||
| foreach (var s in stack) | |||
| s.IsDefault = s.Graph == @default; | |||
| } | |||
| public Graph get_controller() | |||
| { | |||
| if (stack.Count == 0) | |||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
| return stack.First(x => x.IsDefault).Graph; | |||
| } | |||
| public void reset() | |||
| { | |||
| stack.Clear(); | |||
| } | |||
| } | |||
| public class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| } | |||
| } | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | |||
| public bool building_function; | |||
| public Graph() | |||
| { | |||
| _handle = c_api.TF_NewGraph(); | |||
| @@ -113,7 +113,14 @@ namespace Tensorflow | |||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
| } | |||
| public Graph as_default() => ops.set_default_graph(this); | |||
| /// <summary> | |||
| /// Returns a context manager that makes this `Graph` the default graph. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public Graph as_default() | |||
| { | |||
| return ops.set_default_graph(this); | |||
| } | |||
| private Tensor _as_graph_element(object obj) | |||
| { | |||
| @@ -172,13 +172,12 @@ namespace Tensorflow.Layers | |||
| } | |||
| else | |||
| { | |||
| with(tf.variable_scope(scope, default_name: _base_name), | |||
| captured_scope => | |||
| { | |||
| _scope = captured_scope; | |||
| }); | |||
| with(tf.variable_scope(scope, default_name: _base_name), captured_scope => | |||
| { | |||
| // convert variable_scope to VariableScope | |||
| _scope = captured_scope; | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -26,6 +26,7 @@ namespace Tensorflow | |||
| private bool? _reuse; | |||
| bool _in_graph_mode; | |||
| protected Graph _graph; | |||
| bool _building_function; | |||
| public variable_scope(string name, | |||
| string default_name = "", | |||
| @@ -70,6 +71,17 @@ namespace Tensorflow | |||
| public void __enter__() | |||
| { | |||
| // If the default graph is building a function, then we should not replace it | |||
| // with the cached graph. | |||
| if (ops.get_default_graph().building_function) | |||
| _building_function = true; | |||
| else | |||
| _building_function = false; | |||
| if (_in_graph_mode && !_building_function) | |||
| { | |||
| _graph.as_default(); | |||
| } | |||
| _scope = _enter_scope_uncached(); | |||
| } | |||
| @@ -54,12 +54,13 @@ namespace Tensorflow | |||
| public void Dispose() | |||
| { | |||
| var g = get_default_graph(); | |||
| // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||
| g._name_stack = old_stack; | |||
| // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||
| } | |||
| public void __exit__() | |||
| { | |||
| } | |||
| /// <summary> | |||
| @@ -50,7 +50,8 @@ namespace Tensorflow | |||
| return get_default_graph().get_collection_ref(key); | |||
| } | |||
| private static Graph default_graph; | |||
| public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||
| /// <summary> | |||
| /// Returns the default graph for the current thread. | |||
| /// | |||
| @@ -68,15 +69,13 @@ namespace Tensorflow | |||
| { | |||
| //TODO: original source indicates there should be a _default_graph_stack! | |||
| //return _default_graph_stack.get_default() | |||
| if (default_graph == null) | |||
| default_graph = tf.Graph(); | |||
| return default_graph; | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| public static Graph set_default_graph(Graph graph) | |||
| { | |||
| //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||
| default_graph = graph; | |||
| return default_graph; | |||
| default_graph_stack.set_controller(graph); | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| /// <summary> | |||
| @@ -96,10 +95,7 @@ namespace Tensorflow | |||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
| // "nested graphs. If you need a cleared graph, " + | |||
| // "exit the nesting and create a new graph."); | |||
| //_default_graph_stack.reset(); | |||
| if (default_graph!=null) | |||
| default_graph.Dispose(); | |||
| default_graph = tf.Graph(); | |||
| default_graph_stack.reset(); | |||
| } | |||
| public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | |||
| @@ -195,7 +195,7 @@ namespace TensorFlowNET.Examples | |||
| return graph; | |||
| } | |||
| private bool RunWithImportedGraph(Session sess, Graph graph) | |||
| private bool Train(Session sess, Graph graph) | |||
| { | |||
| var stopwatch = Stopwatch.StartNew(); | |||
| @@ -274,8 +274,7 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | |||
| return with(tf.Session(graph), sess | |||
| => RunWithImportedGraph(sess, graph)); | |||
| return with(tf.Session(graph), sess => Train(sess, graph)); | |||
| } | |||
| public bool Predict() | |||