| @@ -22,6 +22,7 @@ namespace Tensorflow | |||||
| private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
| public int _version; | public int _version; | ||||
| private int _next_id_counter; | private int _next_id_counter; | ||||
| private List<String> _unfetchable_ops = new List<string>(); | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| @@ -111,6 +112,20 @@ namespace Tensorflow | |||||
| return ++_next_id_counter; | return ++_next_id_counter; | ||||
| } | } | ||||
| public bool is_fetchable<T>(T tensor_or_op) | |||||
| { | |||||
| if (tensor_or_op is Tensor) | |||||
| { | |||||
| return !_unfetchable_ops.Contains((tensor_or_op as Tensor).name); ; | |||||
| } | |||||
| else if (tensor_or_op is Operation) | |||||
| { | |||||
| return !_unfetchable_ops.Contains((tensor_or_op as Operation).name); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| public string unique_name(string name) | public string unique_name(string name) | ||||
| { | { | ||||
| var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
| @@ -20,5 +20,10 @@ namespace Tensorflow | |||||
| _unique_fetches.Add(fetch); | _unique_fetches.Add(fetch); | ||||
| } | } | ||||
| } | } | ||||
| public List<Object> unique_fetches() | |||||
| { | |||||
| return _unique_fetches; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,10 +10,35 @@ namespace Tensorflow | |||||
| public class _FetchHandler | public class _FetchHandler | ||||
| { | { | ||||
| private _ElementFetchMapper _fetch_mapper; | private _ElementFetchMapper _fetch_mapper; | ||||
| private List<object> _fetches = new List<object>(); | |||||
| private List<bool> _ops = new List<bool>(); | |||||
| private List<object> _final_fetches = new List<object>(); | |||||
| public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | ||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||||
| { | |||||
| switch (fetch) | |||||
| { | |||||
| case Tensor val: | |||||
| _assert_fetchable(graph, val.op); | |||||
| _fetches.Add(fetch); | |||||
| _ops.Add(false); | |||||
| break; | |||||
| } | |||||
| } | |||||
| _final_fetches = _fetches; | |||||
| } | |||||
| private void _assert_fetchable(Graph graph, Operation op) | |||||
| { | |||||
| if (!graph.is_fetchable(op)) | |||||
| { | |||||
| throw new Exception($"Operation {op.name} has been marked as not fetchable."); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||