| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public class Tape | |||
| { | |||
| public static bool IsDtypeTrainable(DataType dtype) | |||
| { | |||
| switch (dtype) | |||
| { | |||
| case DataType.DtHalf: | |||
| case DataType.DtBfloat16: | |||
| case DataType.DtFloat: | |||
| case DataType.DtDouble: | |||
| case DataType.DtComplex64: | |||
| case DataType.DtComplex128: | |||
| case DataType.DtResource: | |||
| case DataType.DtVariant: | |||
| return true; | |||
| default: | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -12,7 +12,19 @@ namespace Tensorflow.Eager | |||
| { | |||
| public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") | |||
| { | |||
| var input_ids = inputs.Select(x => x.Id).ToArray(); | |||
| var input_dtypes = inputs.Select(x => x.dtype).ToArray(); | |||
| bool should_record = false; | |||
| foreach (var input_dtype in input_dtypes) | |||
| { | |||
| if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum())) | |||
| { | |||
| should_record = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!should_record) return; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| @@ -126,11 +127,11 @@ namespace Tensorflow | |||
| Graph._add_op(this); | |||
| } | |||
| public object get_attr(string name) | |||
| public object get_attr<T>(string name) | |||
| { | |||
| AttrValue x = null; | |||
| var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; | |||
| var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; | |||
| using (var buf = new Buffer()) | |||
| { | |||
| @@ -141,12 +142,21 @@ namespace Tensorflow | |||
| switch (name) | |||
| { | |||
| case "T": | |||
| case "dtype": | |||
| return x.Type; | |||
| case "shape": | |||
| return x.Shape; | |||
| default: | |||
| throw new NotImplementedException($"{name}"); | |||
| switch (typeof(T).Name) | |||
| { | |||
| case "Boolean": | |||
| return x.B; | |||
| case "String": | |||
| return x.S; | |||
| default: | |||
| throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||
| } | |||
| } | |||
| } | |||
| @@ -21,12 +21,13 @@ namespace Tensorflow | |||
| var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | |||
| var _result = _op.outputs; | |||
| var _inputs_flat = _op.inputs; | |||
| var _attrs = new Dictionary<string, object>(); | |||
| _attrs["dtype"] = _op.get_attr("dtype"); | |||
| _attrs["shape"] = _op.get_attr("shape"); | |||
| var _attrs = new Dictionary<string, object>(); | |||
| _attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||
| _attrs["shape"] = _op.get_attr<int[]>("shape"); | |||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||
| return new Tensor(_op, 0, dtype); | |||
| } | |||
| @@ -16,6 +16,9 @@ namespace Tensorflow | |||
| { | |||
| private readonly IntPtr _handle; | |||
| private int _id; | |||
| public int Id => _id; | |||
| public Graph Graph => op.Graph; | |||
| public Operation op { get; } | |||
| @@ -2,12 +2,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| { | |||
| public class gen_state_ops | |||
| { | |||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
| public static Execute _execute = new Execute(); | |||
| /// <summary> | |||
| /// Holds state in the form of a tensor that persists across steps. | |||
| @@ -32,6 +34,14 @@ namespace Tensorflow | |||
| var _result = _op.outputs; | |||
| var _inputs_flat = _op.inputs; | |||
| var _attrs = new Dictionary<string, object>(); | |||
| _attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||
| _attrs["shape"] = _op.get_attr<int[]>("shape"); | |||
| _attrs["container"] = _op.get_attr<string>("container"); | |||
| _attrs["shared_name"] = _op.get_attr<string>("shared_name"); | |||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||
| return new Tensor(_op, 0, dtype); | |||
| } | |||
| @@ -56,9 +66,17 @@ namespace Tensorflow | |||
| var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | |||
| var _result = _op.outputs[0]; | |||
| var _result = _op.outputs; | |||
| var _inputs_flat = _op.inputs; | |||
| return _result; | |||
| var _attrs = new Dictionary<string, object>(); | |||
| _attrs["T"] = _op.get_attr<DataType>("T"); | |||
| _attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); | |||
| _attrs["use_locking"] = _op.get_attr<bool>("use_locking"); | |||
| _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||
| return _result[0]; | |||
| } | |||
| } | |||
| } | |||