| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class DefaultGraphStack | |||||
| { | |||||
| List<StackModel> stack = new List<StackModel>(); | |||||
| public void set_controller(Graph @default) | |||||
| { | |||||
| if (!stack.Exists(x => x.Graph == @default)) | |||||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||||
| foreach (var s in stack) | |||||
| s.IsDefault = s.Graph == @default; | |||||
| } | |||||
| public Graph get_controller() | |||||
| { | |||||
| if (stack.Count == 0) | |||||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||||
| return stack.First(x => x.IsDefault).Graph; | |||||
| } | |||||
| public void reset() | |||||
| { | |||||
| stack.Clear(); | |||||
| } | |||||
| } | |||||
| public class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | private Dictionary<string, object> _collections = new Dictionary<string, object>(); | ||||
| public bool building_function; | public bool building_function; | ||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| _handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
| @@ -113,7 +113,14 @@ namespace Tensorflow | |||||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
| } | } | ||||
| public Graph as_default() => ops.set_default_graph(this); | |||||
| /// <summary> | |||||
| /// Returns a context manager that makes this `Graph` the default graph. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public Graph as_default() | |||||
| { | |||||
| return ops.set_default_graph(this); | |||||
| } | |||||
| private Tensor _as_graph_element(object obj) | private Tensor _as_graph_element(object obj) | ||||
| { | { | ||||
| @@ -172,13 +172,12 @@ namespace Tensorflow.Layers | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| with(tf.variable_scope(scope, default_name: _base_name), | |||||
| captured_scope => | |||||
| { | |||||
| _scope = captured_scope; | |||||
| }); | |||||
| with(tf.variable_scope(scope, default_name: _base_name), captured_scope => | |||||
| { | |||||
| // convert variable_scope to VariableScope | |||||
| _scope = captured_scope; | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ namespace Tensorflow | |||||
| private bool? _reuse; | private bool? _reuse; | ||||
| bool _in_graph_mode; | bool _in_graph_mode; | ||||
| protected Graph _graph; | protected Graph _graph; | ||||
| bool _building_function; | |||||
| public variable_scope(string name, | public variable_scope(string name, | ||||
| string default_name = "", | string default_name = "", | ||||
| @@ -70,6 +71,17 @@ namespace Tensorflow | |||||
| public void __enter__() | public void __enter__() | ||||
| { | { | ||||
| // If the default graph is building a function, then we should not replace it | |||||
| // with the cached graph. | |||||
| if (ops.get_default_graph().building_function) | |||||
| _building_function = true; | |||||
| else | |||||
| _building_function = false; | |||||
| if (_in_graph_mode && !_building_function) | |||||
| { | |||||
| _graph.as_default(); | |||||
| } | |||||
| _scope = _enter_scope_uncached(); | _scope = _enter_scope_uncached(); | ||||
| } | } | ||||
| @@ -54,12 +54,13 @@ namespace Tensorflow | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| var g = get_default_graph(); | var g = get_default_graph(); | ||||
| // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||||
| g._name_stack = old_stack; | g._name_stack = old_stack; | ||||
| // Console.WriteLine($"name_scope: {g._name_stack} -> {old_stack}"); | |||||
| } | } | ||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -50,7 +50,8 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref(key); | return get_default_graph().get_collection_ref(key); | ||||
| } | } | ||||
| private static Graph default_graph; | |||||
| public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the default graph for the current thread. | /// Returns the default graph for the current thread. | ||||
| /// | /// | ||||
| @@ -68,15 +69,13 @@ namespace Tensorflow | |||||
| { | { | ||||
| //TODO: original source indicates there should be a _default_graph_stack! | //TODO: original source indicates there should be a _default_graph_stack! | ||||
| //return _default_graph_stack.get_default() | //return _default_graph_stack.get_default() | ||||
| if (default_graph == null) | |||||
| default_graph = tf.Graph(); | |||||
| return default_graph; | |||||
| return default_graph_stack.get_controller(); | |||||
| } | } | ||||
| public static Graph set_default_graph(Graph graph) | public static Graph set_default_graph(Graph graph) | ||||
| { | { | ||||
| //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | ||||
| default_graph = graph; | |||||
| return default_graph; | |||||
| default_graph_stack.set_controller(graph); | |||||
| return default_graph_stack.get_controller(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -96,10 +95,7 @@ namespace Tensorflow | |||||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | ||||
| // "nested graphs. If you need a cleared graph, " + | // "nested graphs. If you need a cleared graph, " + | ||||
| // "exit the nesting and create a new graph."); | // "exit the nesting and create a new graph."); | ||||
| //_default_graph_stack.reset(); | |||||
| if (default_graph!=null) | |||||
| default_graph.Dispose(); | |||||
| default_graph = tf.Graph(); | |||||
| default_graph_stack.reset(); | |||||
| } | } | ||||
| public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | public static Graph _get_graph_from_inputs(params Tensor[] op_input_list) | ||||
| @@ -195,7 +195,7 @@ namespace TensorFlowNET.Examples | |||||
| return graph; | return graph; | ||||
| } | } | ||||
| private bool RunWithImportedGraph(Session sess, Graph graph) | |||||
| private bool Train(Session sess, Graph graph) | |||||
| { | { | ||||
| var stopwatch = Stopwatch.StartNew(); | var stopwatch = Stopwatch.StartNew(); | ||||
| @@ -274,8 +274,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | ||||
| return with(tf.Session(graph), sess | |||||
| => RunWithImportedGraph(sess, graph)); | |||||
| return with(tf.Session(graph), sess => Train(sess, graph)); | |||||
| } | } | ||||
| public bool Predict() | public bool Predict() | ||||