| @@ -12,15 +12,15 @@ TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) st | |||||
|  |  | ||||
| ### How to use | ### How to use | ||||
| Download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. | |||||
| This is only need for Linux and Mac OS, and already packed for Windows. | |||||
| Install TensorFlow.NET through NuGet. | Install TensorFlow.NET through NuGet. | ||||
| ```sh | ```sh | ||||
| PM> Install-Package TensorFlow.NET | PM> Install-Package TensorFlow.NET | ||||
| ``` | ``` | ||||
| If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. | |||||
| Import tensorflow.net. | Import tensorflow.net. | ||||
| ```cs | ```cs | ||||
| using Tensorflow; | using Tensorflow; | ||||
| ``` | ``` | ||||
| @@ -0,0 +1,17 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public class Execute | |||||
| { | |||||
| public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
| { | |||||
| if (inputs == null) | |||||
| inputs = new Tensor[0]; | |||||
| pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| /// <summary> | |||||
| /// python\eager\pywrap_tfe_src.cc | |||||
| /// </summary> | |||||
| public class pywrap_tfe_src | |||||
| { | |||||
| public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -113,7 +113,7 @@ namespace Tensorflow | |||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| _handle = ops._create_c_op(g, node_def, inputs); | _handle = ops._create_c_op(g, node_def, inputs); | ||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
| @@ -128,21 +128,26 @@ namespace Tensorflow | |||||
| public object get_attr(string name) | public object get_attr(string name) | ||||
| { | { | ||||
| object ret = null; | |||||
| AttrValue x = null; | |||||
| var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; | var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; | ||||
| using (var buf = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
| status.Check(true); | |||||
| x = AttrValue.Parser.ParseFrom(buf); | |||||
| } | |||||
| switch (name) | switch (name) | ||||
| { | { | ||||
| case "dtype": | case "dtype": | ||||
| ret = _outputs[0]; | |||||
| break; | |||||
| return x.Type; | |||||
| case "shape": | case "shape": | ||||
| ret = new TensorShapeProto(); | |||||
| break; | |||||
| return x.Shape; | |||||
| default: | |||||
| throw new NotImplementedException($"{name}"); | |||||
| } | } | ||||
| return ret; | |||||
| } | } | ||||
| public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | ||||
| @@ -3,14 +3,16 @@ using System.Collections.Generic; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public static class gen_array_ops | public static class gen_array_ops | ||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| public static Execute _execute = new Execute(); | |||||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | |||||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "") | |||||
| { | { | ||||
| var keywords = new Dictionary<string, object>(); | var keywords = new Dictionary<string, object>(); | ||||
| keywords.Add("dtype", dtype); | keywords.Add("dtype", dtype); | ||||
| @@ -24,6 +26,7 @@ namespace Tensorflow | |||||
| _attrs["dtype"] = _op.get_attr("dtype"); | _attrs["dtype"] = _op.get_attr("dtype"); | ||||
| _attrs["shape"] = _op.get_attr("shape"); | _attrs["shape"] = _op.get_attr("shape"); | ||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
| return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
| } | } | ||||