| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public class Tape | |||||
| { | |||||
| public static bool IsDtypeTrainable(DataType dtype) | |||||
| { | |||||
| switch (dtype) | |||||
| { | |||||
| case DataType.DtHalf: | |||||
| case DataType.DtBfloat16: | |||||
| case DataType.DtFloat: | |||||
| case DataType.DtDouble: | |||||
| case DataType.DtComplex64: | |||||
| case DataType.DtComplex128: | |||||
| case DataType.DtResource: | |||||
| case DataType.DtVariant: | |||||
| return true; | |||||
| default: | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -12,7 +12,19 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | ||||
| { | { | ||||
| var input_ids = inputs.Select(x => x.Id).ToArray(); | |||||
| var input_dtypes = inputs.Select(x => x.dtype).ToArray(); | |||||
| bool should_record = false; | |||||
| foreach (var input_dtype in input_dtypes) | |||||
| { | |||||
| if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum())) | |||||
| { | |||||
| should_record = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!should_record) return; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -126,11 +127,11 @@ namespace Tensorflow | |||||
| Graph._add_op(this); | Graph._add_op(this); | ||||
| } | } | ||||
| public object get_attr(string name) | |||||
| public object get_attr<T>(string name) | |||||
| { | { | ||||
| AttrValue x = null; | AttrValue x = null; | ||||
| var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; | |||||
| var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; | |||||
| using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
| { | { | ||||
| @@ -141,12 +142,21 @@ namespace Tensorflow | |||||
| switch (name) | switch (name) | ||||
| { | { | ||||
| case "T": | |||||
| case "dtype": | case "dtype": | ||||
| return x.Type; | return x.Type; | ||||
| case "shape": | case "shape": | ||||
| return x.Shape; | return x.Shape; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"{name}"); | |||||
| switch (typeof(T).Name) | |||||
| { | |||||
| case "Boolean": | |||||
| return x.B; | |||||
| case "String": | |||||
| return x.S; | |||||
| default: | |||||
| throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,12 +21,13 @@ namespace Tensorflow | |||||
| var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | ||||
| var _result = _op.outputs; | var _result = _op.outputs; | ||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| var _attrs = new Dictionary<string, object>(); | |||||
| _attrs["dtype"] = _op.get_attr("dtype"); | |||||
| _attrs["shape"] = _op.get_attr("shape"); | |||||
| var _attrs = new Dictionary<string, object>(); | |||||
| _attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||||
| _attrs["shape"] = _op.get_attr<int[]>("shape"); | |||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | ||||
| return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
| } | } | ||||
| @@ -16,6 +16,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| private int _id; | |||||
| public int Id => _id; | |||||
| public Graph Graph => op.Graph; | public Graph Graph => op.Graph; | ||||
| public Operation op { get; } | public Operation op { get; } | ||||
| @@ -2,12 +2,14 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class gen_state_ops | public class gen_state_ops | ||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| public static Execute _execute = new Execute(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Holds state in the form of a tensor that persists across steps. | /// Holds state in the form of a tensor that persists across steps. | ||||
| @@ -32,6 +34,14 @@ namespace Tensorflow | |||||
| var _result = _op.outputs; | var _result = _op.outputs; | ||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| var _attrs = new Dictionary<string, object>(); | |||||
| _attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||||
| _attrs["shape"] = _op.get_attr<int[]>("shape"); | |||||
| _attrs["container"] = _op.get_attr<string>("container"); | |||||
| _attrs["shared_name"] = _op.get_attr<string>("shared_name"); | |||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
| return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
| } | } | ||||
| @@ -56,9 +66,17 @@ namespace Tensorflow | |||||
| var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | ||||
| var _result = _op.outputs[0]; | |||||
| var _result = _op.outputs; | |||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| return _result; | |||||
| var _attrs = new Dictionary<string, object>(); | |||||
| _attrs["T"] = _op.get_attr<DataType>("T"); | |||||
| _attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); | |||||
| _attrs["use_locking"] = _op.get_attr<bool>("use_locking"); | |||||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
| return _result[0]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||