| @@ -31,6 +31,50 @@ namespace Tensorflow | |||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| } | } | ||||
| public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | |||||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||||
| } | |||||
| private Func<object> _as_graph_element(object obj) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | |||||
| { | |||||
| string types_str = ""; | |||||
| if (allow_tensor && allow_operation) | |||||
| { | |||||
| types_str = "Tensor or Operation"; | |||||
| } | |||||
| else if (allow_tensor) | |||||
| { | |||||
| types_str = "Tensor"; | |||||
| } | |||||
| else if (allow_operation) | |||||
| { | |||||
| types_str = "Operation"; | |||||
| } | |||||
| var temp_obj = _as_graph_element(obj); | |||||
| if(obj is Tensor && allow_tensor) | |||||
| { | |||||
| if ((obj as Tensor).graph.Equals(this)) | |||||
| { | |||||
| return obj; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new Exception($"Tensor {obj} is not an element of this graph."); | |||||
| } | |||||
| } | |||||
| throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | |||||
| } | |||||
| public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | ||||
| TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
| @@ -8,6 +8,7 @@ namespace Tensorflow | |||||
| public class Operation | public class Operation | ||||
| { | { | ||||
| private Graph _graph; | private Graph _graph; | ||||
| public Graph graph => _graph; | |||||
| public IntPtr _c_op; | public IntPtr _c_op; | ||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| private int _id_value; | private int _id_value; | ||||
| @@ -0,0 +1,24 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Fetch mapper for singleton tensors and ops. | |||||
| /// </summary> | |||||
| public class _ElementFetchMapper : _FetchMapper | |||||
| { | |||||
| private List<Object> _unique_fetches = new List<object>(); | |||||
| private Action _contraction_fn; | |||||
| public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn) | |||||
| { | |||||
| foreach(var tensor in fetches) | |||||
| { | |||||
| var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true); | |||||
| _unique_fetches.Add(fetch); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -9,9 +9,11 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class _FetchHandler | public class _FetchHandler | ||||
| { | { | ||||
| public _FetchHandler(Graph graph, Tensor fetches) | |||||
| { | |||||
| private _ElementFetchMapper _fetch_mapper; | |||||
| public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null) | |||||
| { | |||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class _FetchMapper | |||||
| { | |||||
| public _ElementFetchMapper for_fetch(Tensor fetch) | |||||
| { | |||||
| var fetches = new List<Tensor> { fetch }; | |||||
| return new _ElementFetchMapper(fetches, null); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -13,6 +13,8 @@ namespace Tensorflow | |||||
| private DataType _dtype; | private DataType _dtype; | ||||
| public DataType dtype => _dtype; | public DataType dtype => _dtype; | ||||
| public Graph graph => _op.graph; | |||||
| public string name; | public string name; | ||||
| public Tensor(Operation op, int value_index, DataType dtype) | public Tensor(Operation op, int value_index, DataType dtype) | ||||