| @@ -24,7 +24,7 @@ namespace Tensorflow | |||
| public GFile gfile = new GFile(); | |||
| public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||
| public void import_graph_def(GraphDef graph_def, | |||
| public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
| Dictionary<string, Tensor> input_map = null, | |||
| string[] return_elements = null, | |||
| string name = null, | |||
| @@ -95,6 +95,9 @@ namespace Tensorflow | |||
| throw new NotImplementedException("len() not implemented for type: " + a.GetType()); | |||
| } | |||
| public static float min(float a, float b) | |||
| => Math.Min(a, b); | |||
| public static T[] list<T>(IEnumerable<T> list) | |||
| => list.ToArray(); | |||
| @@ -54,6 +54,7 @@ namespace Tensorflow | |||
| input_map = _ConvertInputMapValues(name, input_map); | |||
| }); | |||
| TF_ImportGraphDefResults results = null; | |||
| var bytes = graph_def.ToByteString().ToArray(); | |||
| using (var buffer = c_api_util.tf_buffer(bytes)) | |||
| using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | |||
| @@ -61,9 +62,8 @@ namespace Tensorflow | |||
| { | |||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
| // need to create a class ImportGraphDefWithResults with IDisposal | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||
| results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||
| status.Check(true); | |||
| c_api.TF_DeleteImportGraphDefResults(results); | |||
| } | |||
| _ProcessNewOps(graph); | |||
| @@ -71,7 +71,34 @@ namespace Tensorflow | |||
| if (return_elements == null) | |||
| return null; | |||
| else | |||
| throw new NotImplementedException("import_graph_def return_elements"); | |||
| return _GatherReturnElements(return_elements, graph, results); | |||
| } | |||
| private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | |||
| Graph graph, | |||
| TF_ImportGraphDefResults results) | |||
| { | |||
| var return_outputs = results.return_tensors; | |||
| var return_opers = results.return_opers; | |||
| var combined_return_elements = new List<ITensorOrOperation>(); | |||
| int outputs_idx = 0; | |||
| int opers_idx = 0; | |||
| foreach(var name in requested_return_elements) | |||
| { | |||
| if (name.Contains(":")) | |||
| { | |||
| combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx])); | |||
| outputs_idx += 1; | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("_GatherReturnElements"); | |||
| // combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx])); | |||
| } | |||
| } | |||
| return combined_return_elements.ToArray(); | |||
| } | |||
| private static void _ProcessNewOps(Graph graph) | |||
| @@ -100,8 +127,29 @@ namespace Tensorflow | |||
| foreach (var name in return_elements) | |||
| { | |||
| throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); | |||
| if(name.Contains(":")) | |||
| { | |||
| var (op_name, index) = _ParseTensorName(name); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
| } | |||
| else | |||
| { | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
| } | |||
| } | |||
| // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
| } | |||
| private static (string, int) _ParseTensorName(string tensor_name) | |||
| { | |||
| var components = tensor_name.Split(':'); | |||
| if (components.Length == 2) | |||
| return (components[0], int.Parse(components[1])); | |||
| else if (components.Length == 1) | |||
| return (components[0], 0); | |||
| else | |||
| throw new ValueError($"Cannot convert {tensor_name} to a tensor name."); | |||
| } | |||
| public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) | |||
| @@ -494,6 +494,12 @@ namespace Tensorflow | |||
| c_api.TF_DeleteGraph(handle); | |||
| } | |||
| public Tensor get_tensor_by_tf_output(TF_Output tf_output) | |||
| { | |||
| var op = _get_operation_by_tf_operation(tf_output.oper); | |||
| return op.outputs[tf_output.index]; | |||
| } | |||
| /// <summary> | |||
| /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>. | |||
| /// This method may be called concurrently from multiple threads. | |||
| @@ -3,13 +3,62 @@ using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_ImportGraphDefResults | |||
| public class TF_ImportGraphDefResults : DisposableObject | |||
| { | |||
| public IntPtr return_tensors; | |||
| public IntPtr return_nodes; | |||
| /*public IntPtr return_nodes; | |||
| public IntPtr missing_unused_key_names; | |||
| public IntPtr missing_unused_key_indexes; | |||
| public IntPtr missing_unused_key_names_data; | |||
| public IntPtr missing_unused_key_names_data;*/ | |||
| public TF_ImportGraphDefResults(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| } | |||
| public TF_Output[] return_tensors | |||
| { | |||
| get | |||
| { | |||
| IntPtr return_output_handle = IntPtr.Zero; | |||
| int num_outputs = -1; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); | |||
| TF_Output[] return_outputs = new TF_Output[num_outputs]; | |||
| unsafe | |||
| { | |||
| var tf_output_ptr = (TF_Output*)return_output_handle; | |||
| for (int i = 0; i < num_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| return return_outputs; | |||
| } | |||
| } | |||
| } | |||
| public TF_Operation[] return_opers | |||
| { | |||
| get | |||
| { | |||
| return new TF_Operation[0]; | |||
| /*TF_Operation return_output_handle = new TF_Operation(); | |||
| int num_outputs = -1; | |||
| c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle); | |||
| TF_Operation[] return_outputs = new TF_Operation[num_outputs]; | |||
| unsafe | |||
| { | |||
| var tf_output_ptr = (TF_Operation*)return_output_handle; | |||
| for (int i = 0; i < num_outputs; i++) | |||
| return_outputs[i] = *(tf_output_ptr + i); | |||
| return return_outputs; | |||
| }*/ | |||
| } | |||
| } | |||
| public static implicit operator TF_ImportGraphDefResults(IntPtr handle) | |||
| => new TF_ImportGraphDefResults(handle); | |||
| public static implicit operator IntPtr(TF_ImportGraphDefResults results) | |||
| => results._handle; | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteImportGraphDefResults(handle); | |||
| } | |||
| } | |||
| @@ -65,9 +65,7 @@ namespace Tensorflow | |||
| } | |||
| public static implicit operator IntPtr(Status status) | |||
| { | |||
| return status._handle; | |||
| } | |||
| => status._handle; | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => TF_DeleteStatus(handle); | |||