| @@ -14,49 +14,61 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Serves as a stack for determining current default graph. | |||
| /// </summary> | |||
| public class DefaultGraphStack | |||
| { | |||
| List<StackModel> stack = new List<StackModel>(); | |||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||
| public void set_controller(Graph @default) | |||
| { | |||
| if (!stack.Exists(x => x.Graph == @default)) | |||
| stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||
| if (!_stack.Exists(x => x.Graph == @default)) | |||
| _stack.Add(new StackModel {Graph = @default, IsDefault = true}); | |||
| foreach (var s in stack) | |||
| foreach (var s in _stack) | |||
| s.IsDefault = s.Graph == @default; | |||
| } | |||
| public Graph get_controller() | |||
| { | |||
| if (stack.Count(x => x.IsDefault) == 0) | |||
| stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
| if (_stack.Count(x => x.IsDefault) == 0) | |||
| _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); | |||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||
| { | |||
| var x = _stack[i]; | |||
| if (x.IsDefault) | |||
| return x.Graph; | |||
| } | |||
| return stack.Last(x => x.IsDefault).Graph; | |||
| throw new TensorflowException("Unable to find a default graph"); | |||
| } | |||
| public bool remove(Graph g) | |||
| { | |||
| var sm = stack.FirstOrDefault(x => x.Graph == g); | |||
| if (sm == null) return false; | |||
| return stack.Remove(sm); | |||
| if (_stack.Count == 0) | |||
| return false; | |||
| var sm = _stack.Find(model => model.Graph == g); | |||
| return sm != null && _stack.Remove(sm); | |||
| } | |||
| public void reset() | |||
| { | |||
| stack.Clear(); | |||
| _stack.Clear(); | |||
| } | |||
| } | |||
| public class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| private class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| } | |||
| } | |||
| } | |||
| } | |||