diff --git a/src/TensorFlowNET.Core/Eager/Tape.cs b/src/TensorFlowNET.Core/Eager/Tape.cs new file mode 100644 index 00000000..6d469370 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/Tape.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Eager +{ + public class Tape /* Home of IsDtypeTrainable, the dtype filter that RecordGradient consults. */ + { + public static bool IsDtypeTrainable(DataType dtype) /* Returns true for half, bfloat16, float, double, complex64, complex128, resource and variant; false for any other DataType. */ + { + switch (dtype) + { + case DataType.DtHalf: + case DataType.DtBfloat16: + case DataType.DtFloat: + case DataType.DtDouble: + case DataType.DtComplex64: + case DataType.DtComplex128: + case DataType.DtResource: + case DataType.DtVariant: + return true; + default: + return false; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs index b972e4ec..79dc67a8 100644 --- a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs +++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs @@ -12,7 +12,19 @@ namespace Tensorflow.Eager { public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary attrs, Tensor[] results, string name = "") { - + var input_ids = inputs.Select(x => x.Id).ToArray(); /* NOTE(review): input_ids is not read anywhere in the lines shown here. */ + var input_dtypes = inputs.Select(x => x.dtype).ToArray(); + + bool should_record = false; /* Becomes true as soon as one input dtype passes Tape.IsDtypeTrainable. */ + foreach (var input_dtype in input_dtypes) + { + if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum())) + { + should_record = true; + break; + } + } + if (!should_record) return; /* Nothing to record when no input is trainable. NOTE(review): no recording logic follows this guard in the lines shown. */ } } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 81743336..a4fd5838 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -126,11 +127,11 @@ namespace Tensorflow Graph._add_op(this); } - public object get_attr(string name) + public object get_attr(string name) /* NOTE(review): the body shown switches on typeof(T), but no type parameter appears in this signature; presumably a generic argument list was lost in extraction. Confirm against the real file. */ { AttrValue x = null; - 
var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; + var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; /* Only the casing of three entries changes: Type, Shape, Tensor. NOTE(review): how fields is consumed lies outside the lines shown. */ using (var buf = new Buffer()) { @@ -141,12 +142,21 @@ namespace Tensorflow switch (name) { + case "T": /* "T" shares the "dtype" branch and returns x.Type. */ case "dtype": return x.Type; case "shape": return x.Shape; default: - throw new NotImplementedException($"{name}"); + switch (typeof(T).Name) /* Other attribute names are resolved by the requested type's name: Boolean yields x.B, String yields x.S. */ + { + case "Boolean": + return x.B; + case "String": + return x.S; + default: + throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); /* Any other requested type is rejected. */ + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 80fd60a9..f9896275 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -21,12 +21,13 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); var _result = _op.outputs; var _inputs_flat = _op.inputs; - var _attrs = new Dictionary(); - _attrs["dtype"] = _op.get_attr("dtype"); - _attrs["shape"] = _op.get_attr("shape"); + var _attrs = new Dictionary(); + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); + return new Tensor(_op, 0, dtype); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index b8ecf011..cf72bb23 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -16,6 +16,9 @@ namespace Tensorflow { private readonly IntPtr _handle; + private int _id; /* NOTE(review): no assignment to _id is visible in this hunk; confirm a constructor sets it, otherwise Id stays at 0. */ + public int Id => _id; + public Graph Graph => op.Graph; public Operation op { get; } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index e1585b42..a36edec6 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ 
b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -2,12 +2,14 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; namespace Tensorflow { public class gen_state_ops { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Execute _execute = new Execute(); /* Shared Execute instance used by the record_gradient calls in this class. */ /// /// Holds state in the form of a tensor that persists across steps. @@ -32,6 +34,14 @@ namespace Tensorflow var _result = _op.outputs; var _inputs_flat = _op.inputs; + var _attrs = new Dictionary(); /* Attribute values read back from the created op and handed to record_gradient. */ + _attrs["dtype"] = _op.get_attr("dtype"); + _attrs["shape"] = _op.get_attr("shape"); + _attrs["container"] = _op.get_attr("container"); + _attrs["shared_name"] = _op.get_attr("shared_name"); + + _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); /* NOTE(review): "Placeholder" looks copied from gen_array_ops; it should presumably be the op name given to _apply_op_helper for this method, which lies outside this hunk. */ + return new Tensor(_op, 0, dtype); } @@ -56,9 +66,17 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); - var _result = _op.outputs[0]; + var _result = _op.outputs; /* Keep every output for record_gradient; the method still returns only the first. */ var _inputs_flat = _op.inputs; - return _result; + + var _attrs = new Dictionary(); + _attrs["T"] = _op.get_attr("T"); + _attrs["validate_shape"] = _op.get_attr("validate_shape"); + _attrs["use_locking"] = _op.get_attr("use_locking"); + + _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); /* Reports the same op name that _apply_op_helper created above. */ + + return _result[0]; } } }