diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
index 2fcf9dce..71312d11 100644
--- a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
+++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Google.Protobuf;
+using Protobuf.Text;
using static Tensorflow.Binding;
namespace Tensorflow.Contexts
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
index c4bce84f..33382703 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
@@ -12,18 +12,36 @@ namespace Tensorflow.Eager
return HasGradientTape();
}
- private bool ShouldRecord(Tensor[] inputs)
+ public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
{
- bool should_record = false;
- foreach (var tape in tf.GetTapeSet())
+ var tape_set = tf.GetTapeSet();
+ var input_ids = MakeTensorIDList(tensors);
+ var input_dtypes = MakeTensorDtypeList(tensors);
+ bool some_tape_watching = false;
+ if (tape_set is not null && tape_set.Count > 0)
{
- if (tape.ShouldRecord(inputs))
+ foreach (var tape in tape_set)
{
- should_record = true;
- break;
+ if (tape.ShouldRecord(input_ids, input_dtypes))
+ {
+ if (tape.Persistent || some_tape_watching)
+ {
+ return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
+ }
+ some_tape_watching = true;
+ }
}
}
- return should_record;
+ // skip the forward_accumulators.
+
+ if (some_tape_watching)
+ {
+ return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
+ }
+ else
+ {
+ return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
+ }
}
}
}
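Note: the three POSSIBLE_GRADIENT_TYPES_* constants returned above drive the forward/backward specialization in ConcreteFunction. A hedged sketch of the intended consumption (names as used in this diff; the branch bodies are illustrative only):

    // classify the call site once instead of re-querying every tape
    int grad_type = tf.Runner.TFE_TapeSetPossibleGradientTypes(inputs);
    if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER)
    {
        // a persistent tape, or several watching tapes: higher-order gradients possible
    }
    else if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
    {
        // exactly one non-persistent tape is watching: first-order functions suffice
    }
    // POSSIBLE_GRADIENT_TYPES_NONE: no tape is watching, backward bookkeeping can be skipped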
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index cfcea55a..59d5fd03 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -13,7 +13,17 @@ namespace Tensorflow.Eager
Tensor[] results,
BackwardFunction backwardFunction = null)
{
- bool should_record = ShouldRecord(inputs);
+ var input_ids = MakeTensorIDList(inputs);
+ var input_dtypes = MakeTensorDtypeList(inputs);
+ bool should_record = false;
+ foreach (var tape in tf.GetTapeSet())
+ {
+ if (tape.ShouldRecord(input_ids, input_dtypes))
+ {
+ should_record = true;
+ break;
+ }
+ }
if (!should_record)
{
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager
op_inputs = inputs;*/
backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
- TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
+ TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);
return true;
}
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager
{
return HasGradientTape();
}
-
- TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
- {
- return tensors.Select(x => x.dtype).ToArray();
- }
}
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
index c96d09e5..1f7b3ae6 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
@@ -1,6 +1,8 @@
-using System;
+using OneOf.Types;
+using System;
using Tensorflow.Gradients;
using Tensorflow.Util;
+using static Tensorflow.Binding;
namespace Tensorflow.Eager
{
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager
///
public partial class EagerRunner
{
+ /// <summary>
+ /// Computes gradients of <paramref name="target"/> with respect to <paramref name="sources"/>.
+ /// </summary>
+ /// <param name="tape"></param>
+ /// <param name="target"></param>
+ /// <param name="sources"></param>
+ /// <param name="output_gradients"></param>
+ /// <param name="sources_raw"></param>
+ /// <param name="unconnected_gradients">determines the value returned if the target and
+ /// sources are unconnected. When 'none' the value returned is None whereas when
+ /// 'zero' a zero tensor in the same shape as the sources is returned.</param>
+ /// <returns></returns>
public Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
- Tensor[] output_gradients)
+ List<Tensor> output_gradients,
+ Tensor[] sources_raw,
+ string unconnected_gradients)
{
- var target_vec = target;
- var sources_vec = sources;
- var sources_set = sources_vec;
+ if (!tape.Persistent)
+ {
+ var tape_set = tf.GetTapeSet();
+ if (tape_set.Contains(tape))
+ {
+ throw new RuntimeError("gradient() cannot be invoked within the " +
+ "GradientTape context (i.e., while operations are being " +
+ "recorded). Either move the call to gradient() to be " +
+ "outside the 'with tf.GradientTape' block, or " +
+ "use a persistent tape: " +
+ "'with tf.GradientTape(persistent=true)'");
+ }
+ }
+
+ var target_vec = MakeTensorIDList(target);
+ var sources_vec = MakeTensorIDList(sources);
+ HashSet<long> sources_set = new HashSet<long>(sources_vec);
+ var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();
+
+ int len = target.Length;
+ for(int i = 0; i < len; i++)
+ {
+ var target_id = target_vec[i];
+ if (sources_set.Contains(target_id))
+ {
+ var tensor = target[i];
+ source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
+ }
+ }
+
+ List<Tensor> outgrad_vec = new();
+ if(output_gradients is not null)
+ {
+ outgrad_vec = output_gradients.ToList();
+ }
+ var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
- var seq_array = target;
- var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();
- for (int i = 0; i < target.Length; ++i)
+ bool unconnected_gradients_zero = unconnected_gradients == "zero";
+ Tensor[] sources_obj = null;
+ if (unconnected_gradients_zero)
{
- source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
+ sources_obj = MakeTensorList(sources_raw);
}
- if (output_gradients != null)
+ if (result.Length > 0)
{
- throw new NotImplementedException("");
+ for(int i = 0; i < result.Length; i++)
+ {
+ if (result[i] is null && unconnected_gradients_zero)
+ {
+ var dtype = sources_obj[i].dtype;
+ result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
+ }
+ }
}
- else
+ return result;
+ }
+
+ Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
+ {
+ return tensors.ToArray();
+ }
+
+ long[] MakeTensorIDList(Tensor[] tensors)
+ {
+ int len = tensors.Length;
+ long[] ids = new long[len];
+ for(int i = 0; i < len; i++)
+ {
+ var tensor = tensors[i];
+ ids[i] = tensor.Id;
+ }
+ return ids;
+ }
+
+ TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
+ {
+ int len = tensors.Length;
+ TF_DataType[] dtypes = new TF_DataType[len];
+ for (int i = 0; i < len; i++)
{
- output_gradients = new Tensor[0];
+ var tensor = tensors[i];
+ dtypes[i] = tensor.dtype;
}
+ return dtypes;
+ }
- var outgrad_vec = MakeTensorList(output_gradients);
+ TapeTensor TapeTensorFromTensor(Tensor tensor)
+ {
+ long id = tensor.Id;
+ var dtype = tensor.dtype;
+ if (tensor is EagerTensor)
+ {
+ var handle = tensor.EagerTensorHandle;
+ if (DTypeNeedsHandleData(dtype))
+ {
+ return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
+ }
+
+ Status status = new();
+ int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
+ long[] dims = new long[num_dims];
+ for(int i = 0; i < num_dims; i++)
+ {
+ dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
+ }
+ Shape tensor_shape = new(dims);
+
+ if(status.Code != TF_Code.TF_OK)
+ {
+ return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
+ }
+ else
+ {
+ return new TapeTensor(id, dtype, tensor_shape);
+ }
+ }
+ var shape_tuple = tensor.shape.dims;
+ if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
+ {
+ return new TapeTensor(id, dtype, tensor);
+ }
+ long[] l = new long[shape_tuple.Length];
+ for(int i = 0; i < shape_tuple.Length; i++)
+ {
+ if (shape_tuple[i] < 0)
+ {
+ l[i] = 0;
+ }
+ else
+ {
+ l[i] = shape_tuple[i];
+ }
+ }
+ return new TapeTensor(id, dtype, new Shape(l));
+ }
- return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
+ bool DTypeNeedsHandleData(TF_DataType dtype)
+ {
+ return dtype == dtypes.variant || dtype == dtypes.resource;
}
- Tensor[] MakeTensorList(Tensor[] tensors)
+ bool ListContainNone(long[] list)
{
- return tensors;
+ int len = list.Length;
+ if(len == 0)
+ {
+ return true;
+ }
+ for(int i = 0; i < len; i++)
+ {
+ if (list[i] == -1)
+ {
+ return true;
+ }
+ }
+ return false;
}
}
}
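The helpers above move the tape bookkeeping from Tensor references to (Id, dtype) pairs. A minimal sketch of how a caller is expected to combine them, assuming tensors is a non-null Tensor[] and System.Linq is in scope:

    long[] ids = tensors.Select(t => t.Id).ToArray();              // MakeTensorIDList
    TF_DataType[] dtypes = tensors.Select(t => t.dtype).ToArray(); // MakeTensorDtypeList
    bool watched = tf.GetTapeSet().Any(tape => tape.ShouldRecord(ids, dtypes));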
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
index e8751aed..9bcc8fe2 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
public partial class EagerRunner
{
void TapeSetRecordBackprop(string op_type,
- Tensor[] input_tensors,
- TapeTensor[] output_tensors,
+ TapeTensor[] output_info,
+ long[] input_ids,
+ TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
if (!CouldBackprop())
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager
foreach (var tape in tf.GetTapeSet())
{
- tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
+ tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function);
}
}
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
index 42e1cff9..3987e7a3 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
@@ -10,18 +10,28 @@ namespace Tensorflow.Eager
public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors,
Tensor[] output_tensors,
+ long[] input_ids,
+ TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
- var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();
-
+ var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
backward_function))
return false;
- TapeSetRecordBackprop(op_type, input_tensors, output_info,
+ TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
backward_function);
return true;
}
+
+ public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
+ Tensor[] input_tensors, BackwardFunction backward_function)
+ {
+ var input_ids = MakeTensorIDList(input_tensors);
+ var input_dtypes = MakeTensorDtypeList(input_tensors);
+ TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
+ backward_function);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
index 7baf4cd7..21a33669 100644
--- a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
+++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
@@ -29,7 +29,14 @@ namespace Tensorflow.Eager
Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
- Tensor[] output_gradients);
+ List<Tensor> output_gradients,
+ Tensor[] sources_raw,
+ string unconnected_gradients);
+
+ void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
+ Tensor[] input_tensors, BackwardFunction backward_function);
+
+ int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);
bool RecordGradient(string op_name,
Tensor[] inputs,
diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
index fbebd4d6..5c2d3a8d 100644
--- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
@@ -18,12 +18,13 @@ namespace Tensorflow.Functions
public class ConcreteFunction: Trackable
{
protected IEnumerable<Tensor> _captured_inputs;
- internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, AttrValue> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function;
+ protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new();
+ internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -156,6 +157,17 @@ namespace Tensorflow.Functions
{
var executing_eagerly = tf.Context.executing_eagerly();
var default_graph = ops.get_default_graph();
+ // TODO(Rinne): deal with `default_graph.building_function`
+
+ var tempvv = func_graph.Variables;
+ if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph)
+ {
+ foreach(var v in this.func_graph.Variables)
+ {
+ resource_variable_ops.variable_accessed(v);
+ }
+ }
+
var tensor_inputs = new Tensors();
foreach (var (i, arg) in enumerate(args))
{
@@ -223,11 +235,16 @@ namespace Tensorflow.Functions
{
input_tangents = new TangentInfo();
}
- if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
+ if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
{
if(input_tangents.Indices is not null || executing_eagerly)
{
- var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
+ string cache_key = "first_order";
+ if(!_tape_functions_cache.TryGetValue(cache_key, out var functions))
+ {
+ functions = new FirstOrderTapeGradientFunctions(func_graph, false);
+ _tape_functions_cache[cache_key] = functions;
+ }
return new ForwardBackwardCall(functions, args, tape_watching: true);
}
else
@@ -241,7 +258,7 @@ namespace Tensorflow.Functions
}
// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
- return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient());
+ return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
}
internal void set_variables(IEnumerable<IVariableV1> variables)
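With the variable_accessed loop added to CallFlat, invoking a concrete function under an active tape now watches its captured variables before the call. A hedged usage sketch (concrete_func and v stand in for a traced function and one of its trainable variables):

    using var tape = tf.GradientTape();
    var outputs = concrete_func.CallFlat(inputs, concrete_func.CapturedInputs);
    // v's handle was watched via resource_variable_ops.variable_accessed(v),
    // so the gradient below is no longer null
    var grad = tape.gradient(outputs[0], v);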
diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
index c2f8e016..cc38683d 100644
--- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
@@ -124,17 +124,16 @@ namespace Tensorflow.Functions
// TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions;
- string config;
- if (function_call_options.config_proto_serialized().Length == 0)
- {
- config = function_utils.get_disabled_rewriter_config().ToString();
- }
- else
- {
- config = function_call_options.config_proto_serialized().ToString();
- }
+ string config = ""; // TODO(Rinne): revise it. The following code should work but does not, for unclear reasons.
- config = ""; // TODO(Rinne): revise it.
+ //if (function_call_options.config_proto_serialized().Length == 0)
+ //{
+ // config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
+ //}
+ //else
+ //{
+ // config = function_call_options.config_proto_serialized().ToStringUtf8();
+ //}
string executor_type = function_call_options.ExecutorType ?? "";
var executing_eagerly = tf.Context.executing_eagerly();
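For context on the commented-out block: a Google.Protobuf ByteString does not expose its payload through ToString(), which is the most likely reason the earlier ToString()-based code misbehaved. A small illustration, assuming cfg is any protobuf message (Google.Protobuf is already imported in this file):

    ByteString bs = cfg.ToByteString();
    string s1 = bs.ToString();      // a debug representation, not the serialized bytes
    string s2 = bs.ToStringUtf8();  // the serialized bytes decoded as a string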
diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
index c0e69dba..bfb0defc 100644
--- a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
+++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
@@ -14,12 +14,11 @@ namespace Tensorflow.Functions
}
- public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
+ public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
+ ForwardAndBackwardFunctions(Tensors inference_args)
{
- var outputs = _func_graph.Outputs;
- (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs)
- = BuildFunctionsForOutputs(outputs, inference_args);
- return _forward_function;
+ var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray();
+ return BuildFunctionsForOutputs(outputs, inference_args);
}
}
}
diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs
index a53df14c..ea1b3eec 100644
--- a/src/TensorFlowNET.Core/Functions/Function.cs
+++ b/src/TensorFlowNET.Core/Functions/Function.cs
@@ -14,7 +14,6 @@ namespace Tensorflow
protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _autograph;
protected TracingCompiler _variable_creation_fn;
- protected bool _has_initialized;
public string Name { get; set; }
public Function(Func csharp_function,
string name, bool auto_graph = true)
@@ -22,7 +21,6 @@ namespace Tensorflow
_csharp_function = csharp_function;
Name = name;
_autograph = auto_graph;
- _has_initialized = false;
}
public virtual Tensors Apply(Tensors inputs)
@@ -38,10 +36,11 @@ namespace Tensorflow
protected virtual Tensors _call(Tensors inputs)
{
- if (!_has_initialized)
+ if(_variable_creation_fn is not null)
{
- _initialize(inputs);
+ return _variable_creation_fn.Apply(inputs);
}
+ _initialize(inputs);
return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs);
@@ -63,7 +62,6 @@ namespace Tensorflow
_variable_creation_fn = _compiler(_csharp_function);
_variable_creation_fn._name = this.Name;
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
- _has_initialized = true;
}
}
}
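Dropping _has_initialized in favor of a null check leaves the cached compiler as the single source of truth for "already traced". The pattern in isolation (hypothetical minimal names, not TensorFlow.NET API):

    private Func<int, int> _compiled;        // null until the first call
    int Call(int x)
    {
        if (_compiled is not null)
            return _compiled(x);             // fast path after initialization
        _compiled = Compile();               // the first call pays the tracing cost
        return _compiled(x);
    }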
diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
index 638aeaf1..3895226e 100644
--- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
+++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
@@ -24,23 +24,40 @@ namespace Tensorflow.Functions
protected string _INFERENCE_PREFIX = "__inference_";
protected FuncGraph _func_graph;
- protected EagerDefinedFunction _forward_function;
+ protected EagerDefinedFunction _forward;
protected FuncGraph _forward_graph;
+ protected List<int> _forwardprop_input_indices;
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
- protected ConcreteFunction _backward_function;
+ protected int _num_inference_outputs;
+ protected int _num_outputs;
+ protected int _num_trainable_inference_outputs;
+ protected ConcreteFunction _backward;
BackwardFunction _backward_function_wrapper;
public TapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps)
{
_func_graph = func_graph;
+ _forward_graph = null;
+ _forward = null;
+ _backward = null;
+ _num_outputs = func_graph.Outputs.Length;
+ _forwardprop_output_indices = null;
+ _num_forwardprop_outputs = 0;
+ _num_inference_outputs = func_graph.Outputs.Length;
+ _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count();
}
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
{
// TODO(Rinne): add input_tangents arg.
- return ForwardAndBackwardFunctions(inference_args);
+ if(_forward is null)
+ {
+ (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
+ = ForwardAndBackwardFunctions(inference_args);
+ }
+ return _forward;
}
///
@@ -51,9 +68,13 @@ namespace Tensorflow.Functions
public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): add arg `input_tagents`.
- var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs);
- tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record,
- getBackwardFunction: backward_function);
+ var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
+ if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0)
+ {
+ // TODO(Rinne): implement it.
+ throw new NotImplementedException();
+ }
+ tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function);
}
///
@@ -65,66 +86,95 @@ namespace Tensorflow.Functions
///
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
+ var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs)
+ .ToDictionary(x => x.Item1, x => x.Item2);
+ var captured_inputs = backward.CapturedInputs;
+ var remapped_captures = captured_inputs.Select(c =>
+ {
+ if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value))
+ {
+ return value;
+ }
+ else
+ {
+ return c;
+ }
+ }).ToArray();
+ if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph))
+ {
+ var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph == forward_graph);
+ throw new RuntimeError($"Failed to map all backward graph captures to " +
+ $"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}");
+ }
+
+ Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>();
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors();
- var trainable_recorded_outputs = 0;
- foreach (var (output_index, output) in enumerate(outputs))
+ int trainable_recorded_outputs = 0;
+ var skip_positions = new HashSet<int>();
+ var relevant_outputs = outputs;
+ foreach (var (output_index, output) in enumerate(relevant_outputs))
{
if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output);
- if (gradients_util.IsTrainable(output))
- trainable_recorded_outputs += 1;
+ if (backprop_util.IsTrainable(output))
+ trainable_recorded_outputs++;
+ else
+ skip_positions.Add(output_index);
+ if (output.dtype == dtypes.variant)
+ variant_zeros_like[output_index] = default_gradient.zeros_like(output);
}
- if(_backward_function_wrapper == null)
+ _backward_function_wrapper = (args, unneeded_gradients) =>
{
- var capture_mapping = new Dictionary<long, Tensor>();
- foreach (var (i, output) in enumerate(outputs))
- capture_mapping[forward_graph.Outputs[i].Id] = output;
-
- var remapped_captures = new Tensors();
- foreach (var capture in backward.CapturedInputs)
- {
- if (capture_mapping.ContainsKey(capture.Id))
- remapped_captures.Add(capture_mapping[capture.Id]);
- }
-
- var skip_positions = new List<int>();
- foreach (var (output_index, output) in enumerate(outputs))
+ if(backward.Outputs is null || backward.Outputs.Length == 0)
{
- if (!gradients_util.IsTrainable(output))
- skip_positions.Add(output_index);
+ return backward.FlatStructuredOutputs;
}
- _backward_function_wrapper = (args, unneeded_gradients) =>
+ var processed_args = new Tensors();
+ int input_index = 0;
+ foreach (var (output_index, arg) in enumerate(args))
{
- var processed_args = new Tensors();
- var input_index = 0;
- foreach (var (output_index, arg) in enumerate(args))
+ if (skip_positions.Contains(output_index))
+ continue;
+ if (arg is null)
+ {
+ var input_placeholder = backward.Inputs[input_index];
+ Tensor variant_arg;
+ if (input_placeholder.dtype == dtypes.variant)
+ {
+ variant_arg = variant_zeros_like[output_index];
+ }
+ else
+ {
+ var (shape, type) = default_gradient.shape_and_dtype(input_placeholder);
+
+ variant_arg = array_ops.zeros(shape, type);
+ }
+ processed_args.Add(variant_arg);
+ }
+ else
{
- if (skip_positions.Contains(output_index))
- continue;
- if (arg == null)
- throw new NotImplementedException("");
processed_args.Add(arg);
- input_index += 1;
- if (input_index >= backward_function_inputs)
- break;
}
+ input_index++;
+ if (input_index >= backward_function_inputs)
+ break;
+ }
- tf.Logger.Debug($"Invoke backward function: {backward.Name}");
- var gradients = backward.CallFlat(processed_args, remapped_captures);
+ tf.Logger.Debug($"Invoke backward function: {backward.Name}");
+ var gradients = backward.CallFlat(processed_args, remapped_captures);
- foreach (var unneeded_gradient_index in unneeded_gradients)
- {
- var index = Convert.ToInt32(unneeded_gradient_index);
- if (gradients.Length <= index)
- gradients.Insert(index, null);
- }
+ foreach (var unneeded_gradient_index in unneeded_gradients)
+ {
+ var index = Convert.ToInt32(unneeded_gradient_index);
+ if (gradients.Length <= index)
+ gradients.Insert(index, null);
+ }
- return gradients;
- };
- }
+ return gradients;
+ };
return (_backward_function_wrapper, recorded_outputs);
}
@@ -143,7 +193,7 @@ namespace Tensorflow.Functions
}
}
- var backwards_graph = new FuncGraph(_func_graph.Name);
+ var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name));
backwards_graph.as_default();
var gradients_wrt_outputs = new List<Tensor>();
foreach (var output in trainable_outputs)
@@ -153,6 +203,7 @@ namespace Tensorflow.Functions
gradients_wrt_outputs.Add(gradient_placeholder);
handle_data_util.copy_handle_data(output, gradient_placeholder);
}
+ // TODO(Rinne): with ops.device(None)
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
@@ -175,7 +226,8 @@ namespace Tensorflow.Functions
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));
- var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
+ var (wrapped_forward_function, wrapped_backward_function) =
+ monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
//var backward_function_attr = new Dictionary();
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
@@ -189,10 +241,11 @@ namespace Tensorflow.Functions
// _func_graph.Inputs, _func_graph.Outputs,
// monomorphic_function_utils._parse_func_attrs(forward_function_attr));
- return (forward_function, _func_graph, backward_function, null, 0);
+ return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0);
}
- public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
+ public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
+ ForwardAndBackwardFunctions(Tensors inference_args)
{
throw new NotImplementedException("");
}
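The memoized Forward and the rewritten Record are meant to bracket the inference call. A hedged sketch of the sequence (the exact invocation API of the forward function is not shown in this diff and is assumed here):

    var forward = functions.Forward(inference_args);  // builds forward/backward graphs once, cached in _forward
    var flat_outputs = forward.Call(inference_args);  // assumed invocation entry point
    functions.Record(flat_outputs, inference_args);   // registers the wrapped backward function on all tapes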
diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs
index 8a844671..fb109595 100644
--- a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs
+++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs
@@ -73,12 +73,12 @@ namespace Tensorflow.Functions
private static string male_cache_key(Tensor[] inputs)
{
- string res = "";
- foreach (var input in inputs)
- {
- res += $"{input.name}_{input.Id}";
- }
- return res;
+ //string res = "";
+ //foreach (var input in inputs)
+ //{
+ // res += $"{input.name}_{input.Id}";
+ //}
+ return inputs.Length.ToString();
}
}
}
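Keying the trace cache on inputs.Length alone means calls that differ only in dtype or shape now share a single concrete trace. For comparison, a stricter key would also fold those in; the sketch below is purely illustrative and not part of this change:

    static string make_cache_key(Tensor[] inputs)
        => string.Join("|", inputs.Select(t => $"{t.dtype}:{t.shape}")) + $"#{inputs.Length}";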
diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs
index acf00597..7cb5c705 100644
--- a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs
+++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs
@@ -153,7 +153,7 @@ namespace Tensorflow.Functions
foreach(var tape in tf.GetTapeSet())
{
tape.RecordOperation(_inference_function.Signature.Name, to_record,
- inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function);
+ inference_args, backward_function);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
index eee98a61..743ed0d8 100644
--- a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
+++ b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients
/// Map from tensor to how many references still exist for this tensor in
/// the tape.
///
- public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
+ public UnorderedMap<long, long> tensor_usage_counts { get; set; }
///
/// Maps from op ID to how many output tensors of this op still need to have
/// their gradients computed.
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients
public BackpropInitialState()
{
op_tape = new OpTape();
- tensor_usage_counts = new UnorderedMap<Tensor, long>();
+ tensor_usage_counts = new UnorderedMap<long, long>();
op_missing_tensor = new UnorderedMap();
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
index 31517e58..b5fd373e 100644
--- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients
///
///
///
- public Tensor gradient(Tensor target, Tensor source)
+ public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null,
+ string unconnected_gradients = null)
{
+ if(_tape is null)
+ {
+ throw new RuntimeError("A non-persistent GradientTape can only be used to " +
+ "compute one set of gradients (or jacobians).");
+ }
+
ITape tape = stop_recording();
var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
new[] { source },
- null);
+ output_gradients,
+ new[] { source },
+ unconnected_gradients);
return results[0];
}
- public Tensor gradient(Tensor target, ResourceVariable source)
+ public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null,
+ string unconnected_gradients = null)
{
- var results = gradient(target, new List<IVariableV1> { source });
+ var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients);
return results[0];
}
- public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
+ public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null,
+ string unconnected_gradients = null)
{
- var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
+ var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients);
return (results[0], results[1]);
}
- public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
+ public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null,
+ string unconnected_gradients = null)
{
+ if (_tape is null)
+ {
+ throw new RuntimeError("A non-persistent GradientTape can only be used to " +
+ "compute one set of gradients (or jacobians).");
+ }
var tape = stop_recording();
var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
sources.Select(x => x.Handle).ToArray(),
- null);
+ output_gradients,
+ sources.Select(x => x.Handle).ToArray(),
+ unconnected_gradients);
if (!tape.Persistent)
{
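A hedged usage sketch of the widened gradient overloads (a persistent tape so gradient can be called twice; x stands for any watched float tensor):

    using var tape = tf.GradientTape(persistent: true);
    var x = tf.constant(3.0f);
    tape.watch(x);
    var y = x * x;                                               // connected to x
    var z = tf.constant(5.0f);                                   // unconnected to x
    var gy = tape.gradient(y, x);                                // 2x = 6.0
    var gz = tape.gradient(z, x, unconnected_gradients: "zero"); // zeros shaped like x instead of null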
diff --git a/src/TensorFlowNET.Core/Gradients/ITape.cs b/src/TensorFlowNET.Core/Gradients/ITape.cs
index dbd085ea..07594dab 100644
--- a/src/TensorFlowNET.Core/Gradients/ITape.cs
+++ b/src/TensorFlowNET.Core/Gradients/ITape.cs
@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients
public interface ITape
{
void SetTapeId(int id);
- bool ShouldRecord(Tensor[] tensors);
+ bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes);
void StartRecord();
void StopRecord();
bool Persistent { get; }
void RecordOperation(string op_type,
- Tensor[] input_tensors,
TapeTensor[] output_tensors,
+ long[] input_tensor_id,
+ TF_DataType[] input_dtypes,
BackwardFunction backward_function);
- void VariableAccessed(ResourceVariable variable);
+ void RecordOperation(string op_type,
+ Tensor[] outputs,
+ Tensor[] inputs,
+ BackwardFunction backward_function);
+
+ void VariableAccessed(IVariableV1 variable);
void Watch(Tensor x);
- ResourceVariable[] WatchedVariables();
+ IVariableV1[] WatchedVariables();
- Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
- Tensor[] source_tensor_ids,
- UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
- Tensor[] output_gradients);
+ Tensor[] ComputeGradient(long[] target_tensor_ids,
+ long[] source_tensor_ids,
+ UnorderedMap<long, TapeTensor> sources_that_are_targets,
+ List<Tensor> output_gradients,
+ bool build_default_zeros_grads);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
index 537369dd..7665fa01 100644
--- a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
+++ b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients
{
public string op_type { get; set; }
public TapeTensor[] output_tensor_info { get; set; }
- public Tensor[] input_tensor_id { get; set; }
+ public long[] input_tensor_id { get; set; }
public BackwardFunction backward_function { get; set; }
public override string ToString()
- => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
+ => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
index 73c9e501..8a4a41f6 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
@@ -2,235 +2,246 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Util;
+using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
public partial class Tape
{
- // int kMinAggregateCount = 4;
- // int kMinAggregateBytes = 128 * 1024 * 1024;
+ static readonly int kMinAggregateCount = 4;
+ static readonly int kMinAggregateBytes = 128 * 1024 * 1024;
+ private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap;
- public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
- Tensor[] source_tensor_ids,
- UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
- Tensor[] output_gradients)
+ static Tape()
{
- var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
- // var gradients_size = new UnorderedMap<long, long>();
- var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap();
- var state = PrepareBackprop(
- target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
- var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
- var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
- output_gradients,
- tensor_tape_,
- state.op_tape);
+ _functionsAcceptingNoneForIndicesMap = new();
+ _functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 }));
+ _functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 }));
+ _functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet(new[] { 1, 2, 3, 4 }));
+ }
- while (!op_stack.empty())
+ public Tensor[] ComputeGradient(long[] target_tensor_ids,
+ long[] source_tensor_ids,
+ UnorderedMap<long, TapeTensor> sources_that_are_targets,
+ List<Tensor> output_gradients,
+ bool build_default_zeros_grads)
+ {
+ UnorderedSet<long> sources_set = new(source_tensor_ids);
+ BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent);
+ var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
+ var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape);
+ UnorderedMap<long, long> gradients_size = new();
+ while(op_stack.Count > 0)
{
- var op = op_stack.Dequeue();
- if (!state.op_tape.find(op, out var trace))
+ long op = op_stack.Dequeue();
+ if(!state.op_tape.TryGetValue(op, out var op_it))
+ {
continue;
-
- // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
+ }
+ var trace = op_it;
state.op_tape.erase(op);
-
- var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
- var unneeded_gradients = new List<long>();
- for (int i = 0; i < trace.input_tensor_id.Length; i++)
+ List<Tensor> out_gradients = new();
+ List<long> unneeded_gradients = new();
+ for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++)
{
- var in_tensor_id = trace.input_tensor_id[i];
- if (!tensor_tape_.find(in_tensor_id) &&
- !sources_set.find(in_tensor_id))
+ long in_tensor_id = trace.input_tensor_id[i];
+ if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id))
+ {
unneeded_gradients.Add(i);
+ }
}
bool any_gradient_nonzero = false;
- var zero_indices = new List<int>();
- for (int i = 0; i < trace.output_tensor_info.Length; ++i)
+ List<int> zero_indices = new();
+ for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++)
{
- var id = trace.output_tensor_info[i].GetTensor();
- if (!gradients.find(id, out var grad_it))
+ long id = trace.output_tensor_info[i].GetID();
+ if(!gradients.TryGetValue(id, out var grad_it))
{
- if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) &&
- func_name_it.find(i))
+ out_gradients.Add(null);
+ if (build_default_zeros_grads)
{
- out_gradients.Add(null);
- }
- else
- {
- out_gradients.Add(null);
- zero_indices.Add(i);
+ if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) ||
+ !func_name_it.find(i))
+ {
+ zero_indices.Add(i);
+ }
}
}
else
{
any_gradient_nonzero = true;
- var new_gradients = grad_it.Count == 1 ?
- grad_it[0] :
- gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients
-
+ Tensor new_gradients;
+ if (grad_it.Count == 1)
+ {
+ new_gradients = grad_it[0];
+ }
+ else
+ {
+ new_gradients = AggregateGradients(grad_it);
+ }
if (!sources_set.find(id))
+ {
gradients.Remove(id);
+ }
else
{
- // grad_it.Clear();
- // grad_it.Add(new_gradients);
- // vspace.MarkAsResult(new_gradients);
+ grad_it.Clear();
+ grad_it.Add(new_gradients);
+ // MarkAsResult
}
out_gradients.Add(new_gradients);
}
}
- Tensor[] in_gradients;
+ Tensor[] in_gradients = new Tensor[0];
if (any_gradient_nonzero)
{
- // foreach (var i in zero_indices)
- // out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
-
- in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray());
-
- if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length)
- throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
- if (!_persistent)
+ foreach(var i in zero_indices)
{
- // trace.backward_function_deleter(trace.backward_function);
- trace.backward_function = null;
+ out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
}
+ in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients);
}
else
{
- in_gradients = new Tensor[trace.input_tensor_id.Length];
+ out_gradients.Clear();
}
-
- bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length;
- for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k)
+
+ for(int i = 0, end = in_gradients.Length; i < end; i++)
{
- if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k;
- var id = trace.input_tensor_id[k];
- if (in_gradients[i] != null)
+ long id = trace.input_tensor_id[i];
+ if (in_gradients[i] is not null)
{
- var unaggregated_grads = gradients[id];
+ var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>());
unaggregated_grads.Add(in_gradients[i]);
- /*if (unaggregated_grads.Count > kMinAggregateCount)
+ if(unaggregated_grads.Count > kMinAggregateCount)
{
- if (!gradients_size.find(id, out var size))
+ if(!gradients_size.TryGetValue(id, out var size))
{
- size = (long)unaggregated_grads[0].size;
+ size = NumElements(unaggregated_grads[0]);
gradients_size.emplace(id, size);
}
-
- if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
+ if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
{
- throw new NotImplementedException("");
+ Tensor grad = AggregateGradients(unaggregated_grads);
+ unaggregated_grads.Clear();
+ unaggregated_grads.Add(grad);
}
- }*/
+ }
}
-
- if (!state.tensor_usage_counts.find(id))
+ if(!state.tensor_usage_counts.find(id))
+ {
continue;
-
+ }
state.tensor_usage_counts[id]--;
- if (state.tensor_usage_counts[id] > 0)
+ if(state.tensor_usage_counts[id] > 0)
+ {
continue;
-
- if (!tensor_tape_.find(id, out var tape_it))
+ }
+ if (!tensor_tape_.TryGetValue(id, out var tape_it))
{
- if (gradients.find(id, out var grad_it))
+ if (gradients.find(id))
{
- // foreach (var g in grad_it)
- // DeleteGradient(g);
gradients.erase(id);
}
continue;
}
-
- var op_id = tape_it;
- if (op_id == -1)
+ long op_id = tape_it;
+ if(op_id == -1)
+ {
continue;
-
- if (state.op_missing_tensor.find(op_id, out var missing_it))
+ }
+ if(state.op_missing_tensor.find(op_id))
{
state.op_missing_tensor[op_id]--;
- if (state.op_missing_tensor[op_id] == 0)
+ if(state.op_missing_tensor[op_id] == 0)
+ {
op_stack.Enqueue(op_id);
+ }
}
}
}
- if (state.op_tape.Count > 0)
+ if(state.op_tape.Count > 0)
+ {
throw new RuntimeError("Invalid tape state.");
-
- var result = new Tensor[source_tensor_ids.Length];
- var j = 0;
- foreach (var id in source_tensor_ids)
+ }
+ Tensor[] result = new Tensor[source_tensor_ids.Length];
+ for(int i = 0; i < source_tensor_ids.Length; i++)
{
- if (gradients.find(id, out var grad_it))
+ long tensor_id = source_tensor_ids[i];
+ if(!gradients.TryGetValue(tensor_id, out var grad_it))
{
- if (grad_it.Count > 1)
- result[j] = gen_math_ops.add_n(grad_it.ToArray());
- else
- result[j] = grad_it[0];
+ result[i] = null;
+ }
+ else
+ {
+ if(grad_it.Count > 1)
+ {
+ Tensor grad = AggregateGradients(grad_it);
+ grad_it.Clear();
+ grad_it.Add(grad);
+ }
+ result[i] = grad_it[0];
}
- j++;
}
-
return result;
}
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
{
- var m = new UnorderedMap<string, UnorderedSet<int>>();
- m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
- m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
- m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
- return m;
+ return _functionsAcceptingNoneForIndicesMap;
}
- UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
- UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
- Tensor[] output_gradients,
+ UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
+ UnorderedMap<long, TapeTensor> sources_that_are_targets,
+ List<Tensor> output_gradients,
TensorTape tensor_tape,
OpTape op_tape)
{
- var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
- for (int i = 0; i < target_tensor_ids.Length; ++i)
+ var result = new UnorderedMap<long, List<Tensor>>();
+ for(int i = 0, end = target_tensor_ids.Length; i < end; i++)
{
- var id = target_tensor_ids[i];
- if (output_gradients.Length == 0 || output_gradients[i] == null)
+ long id = target_tensor_ids[i];
+ if (output_gradients is null || output_gradients.Count == 0 || output_gradients[i] is null)
{
- if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
+ if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1)
{
- if (!op_tape.find(tensor_tape[id], out var op_it))
+ if(!op_tape.TryGetValue(tensor_it, out var op_it))
+ {
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
- "failed to find operation producing a tensor");
+ "failed to find operation producing a tensor.");
+ }
bool found = false;
- for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
+ for(int j = 0; j < op_it.output_tensor_info.Length; j++)
{
- if (op_it.output_tensor_info[j].GetTensor() == id)
+ if (op_it.output_tensor_info[j].GetID() == id)
{
found = true;
- var ones = op_it.output_tensor_info[j].OnesLike();
- result[id].Add(ones);
+ Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]);
+ result.SetDefault(id, new List<Tensor>()).Add(ones_like);
break;
}
}
-
if (!found)
{
- throw new ValueError("Internal state of the gradient tape is invalid: " +
- "none of operations outputs match expected tensor");
+ throw new RuntimeError("Internal state of the gradient tape is invalid: " +
+ "none of operations outputs match expected tensor.");
}
}
else
{
- if (sources_that_are_targets.find(id, out var source_tensor))
- result[id].Add(source_tensor.OnesLike());
+ if(sources_that_are_targets.TryGetValue(id, out var source_tensor))
+ {
+ Tensor ones_like = BuildOnesLike(source_tensor);
+ result.SetDefault(id, new List<Tensor>()).Add(ones_like);
+ }
}
}
else
{
- result[id].Add(output_gradients[i]);
+ result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]);
}
}
@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients
}
return result;
}
+
+ Tensor BuildOnesLike(TapeTensor t)
+ {
+ return t.OnesLike();
+ }
+
+ Tensor AggregateGradients(List<Tensor> gradient_tensors)
+ {
+ if(gradient_tensors.Count == 1)
+ {
+ // a single pending gradient needs no add_n
+ return gradient_tensors[0];
+ }
+ return tf.add_n(gradient_tensors.ToArray());
+ }
+
+ void DeleteGradient(Tensor gradient)
+ {
+ // Do not do anything here. Because GC will collect it when it has no reference.
+ }
+
+ long NumElements(Tensor tensor) => 1; // placeholder: with size 1, the byte threshold (Count * 4 bytes) is effectively never reached
}
}
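InitialGradients seeds each target with ones (via BuildOnesLike) unless the caller supplies output_gradients explicitly. A quick numeric check of the chain this enables, sketched with the public tape API:

    using var tape = tf.GradientTape();
    var x = tf.constant(3.0f);
    tape.watch(x);
    var y = x * x;                // dy/dx = 2x
    var dx = tape.gradient(y, x); // y is seeded with 1.0f, so dx evaluates to 6.0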
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
index 2ab4ddbb..f8f356e7 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients
{
public partial class Tape
{
- public BackpropInitialState PrepareBackprop(Tensor[] target,
+ public BackpropInitialState PrepareBackprop(long[] target,
TensorTape tensor_tape,
OpTape op_tape,
- UnorderedSet<Tensor> sources_set,
+ UnorderedSet<long> sources_set,
bool persistent_tape)
{
+ Stack<long> tensor_stack = new Stack<long>();
+ foreach(var t in target)
+ {
+ tensor_stack.Push(t);
+ }
BackpropInitialState result = new BackpropInitialState();
- var tensor_stack = new Queue<Tensor>(target);
- while (tensor_stack.Count > 0)
+ while(tensor_stack.Count > 0)
{
- var tensor_id = tensor_stack.Dequeue();
-
- if (!tensor_tape.find(tensor_id, out var op_id))
+ long tensor_id = tensor_stack.Pop();
+ if(!tensor_tape.TryGetValue(tensor_id, out var op_id))
+ {
continue;
-
- if (op_id == -1 ||
- !op_tape.find(op_id, out var op_it) ||
- result.op_tape.find(op_id, out var result_op_it))
+ }
+ if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it)
+ || result.op_tape.find(op_id))
+ {
continue;
-
+ }
result.op_tape.emplace(op_id, op_it);
-
- foreach (var it in op_it.input_tensor_id)
+ foreach(var it in op_it.input_tensor_id)
{
- if (result.tensor_usage_counts.find(it))
+ if(result.tensor_usage_counts.find(it))
+ {
result.tensor_usage_counts[it]++;
+ }
else
{
result.tensor_usage_counts[it] = 1;
if (tensor_tape.find(it))
- tensor_stack.Enqueue(it);
+ {
+ tensor_stack.Push(it);
+ }
}
}
-
if (!persistent_tape)
- op_tape.Remove(op_id);
+ {
+ op_tape.erase(op_id);
+ }
}
-
- foreach (var pair in result.tensor_usage_counts)
+ foreach(var pair in result.tensor_usage_counts)
{
- if (tensor_tape.find(pair.Key, out var it) && it != -1)
- result.op_missing_tensor[it] += 1;
+ if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1)
+ {
+ result.op_missing_tensor[it]++;
+ }
}
-
if (!persistent_tape)
{
- // Call destructors for all unneeded gradient functions and
- // clear the op_tape. We can clear the tape because ownership of
- // backward functions that will be used for gradient computation
- // has been transferred to `result`.
- /*for (const auto&op_pair : *op_tape) {
- op_pair.second.backward_function_deleter(
- op_pair.second.backward_function);
- }*/
op_tape.Clear();
}
-
return result;
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
index a692f1f4..708b9121 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients
public partial class Tape
{
long next_op_id_ = 0;
- UnorderedMap<Tensor, long> tensor_usage_;
+ UnorderedMap<long, long> tensor_usage_;
public void RecordOperation(string op_type,
- Tensor[] input_tensors,
TapeTensor[] output_tensors,
+ long[] input_tensor_id,
+ TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
- if (!ShouldRecord(input_tensors))
+ if (!ShouldRecord(input_tensor_id, input_dtypes))
return;
- var op_id = next_op_id_++;
- foreach (var i in input_tensors)
+ foreach (var i in input_tensor_id)
+ {
tensor_usage_[i]++;
-
+ }
+ long op_id = next_op_id_++;
+
foreach (var o in output_tensors)
{
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
- tensor_tape_[o.GetTensor()] = op_id;
- tensor_usage_[o.GetTensor()] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
}
op_tape_[op_id] = new OpTapeEntry
{
op_type = op_type,
- output_tensor_info = output_tensors,
- input_tensor_id = input_tensors,
+ output_tensor_info = output_tensors.ToArray(),
+ input_tensor_id = input_tensor_id.ToArray(),
backward_function = backward_function
};
}
+
+ public void RecordOperation(string op_type,
+ Tensor[] outputs,
+ Tensor[] inputs,
+ BackwardFunction backward_function)
+ {
+ tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs
index 15caf81b..648666bb 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients
_created_eagerly = tf.Context.executing_eagerly();
tensor_tape_ = new TensorTape();
op_tape_ = new OpTape();
- tensor_usage_ = new UnorderedMap<Tensor, long>();
+ tensor_usage_ = new UnorderedMap<long, long>();
if(_created_eagerly)
tf.Context.start_step();
// nesting_id = ++tape_nesting_id_counter;
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients
public void Watch(Tensor x)
{
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
- tensor_tape_.emplace(x, -1);
+ tensor_tape_.emplace(x.Id, -1);
}
- public bool ShouldRecord(Tensor[] tensors)
+ public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes)
{
- var dtypes = tensors.Select(x => x.dtype).ToArray();
- for (int i = 0; i < tensors.Length; ++i)
+ Debug.Assert(tensor_ids.Length == tensor_dtypes.Length);
+ for (int i = 0; i < tensor_ids.Length; ++i)
{
- if (tensor_tape_.find(tensors[i]))
+ if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i]))
{
- if (IsDtypeTrainable(dtypes[i]))
- return true;
+ return true;
}
}
return false;
}
- public void VariableAccessed(ResourceVariable variable)
+ public void VariableAccessed(IVariableV1 variable)
{
Watch(variable.Handle);
}
- public ResourceVariable[] WatchedVariables()
+ public IVariableV1[] WatchedVariables()
{
return null;
}
diff --git a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
index 210794d8..3ad19768 100644
--- a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
+++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
@@ -1,27 +1,63 @@
-using static Tensorflow.Binding;
+using OneOf;
+using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
public class TapeTensor
{
- Tensor tensor;
- long id => tensor.Id;
- TF_DataType dtype => tensor.dtype;
- Shape shape => tensor.shape;
+ internal Tensor tensor;
+ internal long id;
+ internal TF_DataType dtype;
+ internal OneOf<Shape, Tensor> shape;
+
+ public TapeTensor(long id, TF_DataType dtype, Shape shape)
+ {
+ this.id = id;
+ this.dtype = dtype;
+ this.shape = shape;
+ }
+
+ public TapeTensor(long id, TF_DataType dtype, Tensor shape)
+ {
+ this.id = id;
+ this.dtype = dtype;
+ this.shape = shape;
+ }
public TapeTensor(Tensor tensor)
{
+ this.id = tensor.Id;
+ this.dtype = tensor.dtype;
+ this.shape = tensor.shape;
this.tensor = tensor;
}
- public long GetID() => tensor.Id;
- public Tensor GetTensor() => tensor;
+ public long GetID() => id;
public Tensor ZerosLike()
- => tf.zeros(shape: shape, dtype: dtype);
+ {
+ if(dtype == dtypes.resource)
+ {
+ return null;
+ }
+ if(shape.Index == 1)
+ {
+ return tf.zeros_like(shape.AsT1);
+ }
+ return tf.zeros(shape.AsT0, dtype);
+ }
public Tensor OnesLike()
- => tf.ones(shape: shape, dtype: dtype);
+ {
+ if (shape.Index == 1)
+ {
+ return tf.ones_like(shape.AsT1);
+ }
+ return tf.ones(shape.AsT0, dtype);
+ }
+
+ //public Tensor OnesLike()
+ // => tf.ones(shape: shape, dtype: dtype);
public override string ToString()
=> $"{id}, {shape}, {dtype.as_numpy_name()}";
diff --git a/src/TensorFlowNET.Core/Gradients/TensorTape.cs b/src/TensorFlowNET.Core/Gradients/TensorTape.cs
index b9424f91..3f069082 100644
--- a/src/TensorFlowNET.Core/Gradients/TensorTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/TensorTape.cs
@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients
/// produced this tensor. A value of -1 means that the tensor was directly
/// watched and not the result of any operation in the tape.
///
- public class TensorTape : UnorderedMap<Tensor, long>
+ public class TensorTape : UnorderedMap<long, long>
{
}
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 10166911..71d3d9ca 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -704,32 +704,7 @@ namespace Tensorflow
public static int PossibleTapeGradientTypes(Tensor[] tensors)
{
- var tape_set = tf.GetTapeSet();
- bool some_tape_watching = false;
- if(tape_set is not null && tape_set.Count > 0)
- {
- foreach(var tape in tape_set)
- {
- if (tape.ShouldRecord(tensors))
- {
- if(tape.Persistent || some_tape_watching)
- {
- return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
- }
- some_tape_watching = true;
- }
- }
- }
- // skip the forward_accumulators.
-
- if (some_tape_watching)
- {
- return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
- }
- else
- {
- return POSSIBLE_GRADIENT_TYPES_NONE;
- }
+ return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors);
}
///
diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
index 9ef0b95b..ea415969 100644
--- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
@@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable
return tensor;
}
+ public void watch_variable(IVariableV1 v)
+ {
+ if (_resource_tensor_inputs.Contains(v.Handle))
+ {
+ return;
+ }
+ _watched_variables.Add(new WeakReference(v));
+ //this = this.outer_graph;
+ }
+
Tensor capture_eager_tensor(Tensor tensor, string name)
{
Tensor graph_const = null;
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
index 68d6d059..58e7ef8c 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
@@ -4,10 +4,10 @@ public interface IOptimizer
{
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
Tensor[] clip_gradients(Tensor[] grads);
- void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+ void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
- void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+ void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
}
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 43dc8cd4..e5f55631 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -208,9 +208,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
- //[DllImport(TensorFlowLibName)]
- //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
- //[DllImport(TensorFlowLibName)]
- //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index 9c2e85d1..10547921 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -39,7 +39,7 @@ namespace Tensorflow
if (config is null)
{
- config = function_utils.get_disabled_rewriter_config().ToString();
+ config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
}
if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow
if (executing_eagerly)
{
+ // TODO(Rinne): implement it.
+
throw new NotImplementedException();
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 93a54af0..1dc6504a 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
+using Tensorflow.Eager;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null)
- => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
+ {
+ var ctx = tf.Context;
+ if (ctx.executing_eagerly())
+ {
+ try
+ {
+ var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
+ return _result[0];
+ }
+ catch (Exception)
+ {
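+ // Fast path failed; fall through to the eager fallback below.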
+
+ }
+ try
+ {
+ return fill_eager_fallback(dims, ops.convert_to_tensor(value), name, ctx);
+ }
+ catch (Exception)
+ {
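+ // Eager fallback failed too; fall through to graph-mode op construction.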
+
+ }
+ }
+ Dictionary<string, object> attrs = new Dictionary<string, object>();
+ attrs["dims"] = dims;
+ attrs["value"] = value;
+ var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
+ if (execute.must_record_gradient())
+ {
+ throw new NotImplementedException();
+ }
+ return result.output;
+ }
+
+ public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
+ {
+ object[] attrs = new object[] { "T", value.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
+ var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);
+
+ if (execute.must_record_gradient())
+ {
+ throw new NotImplementedException();
+ }
+ return _result[0];
+ }
+ //=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
/// </summary>
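Note: the expanded `fill<T>` follows the shape of TensorFlow's generated Python op wrappers: eager fast path, then eager fallback, then graph-mode `_apply_op_helper`. The control flow, distilled (the `Func<Tensor>` parameters are placeholders for this sketch, not library API):

```csharp
using System;
using Tensorflow;
using Tensorflow.Contexts;

static class OpDispatchSketch
{
    public static Tensor Dispatch(Context ctx, Func<Tensor> fastPath,
                                  Func<Tensor> fallback, Func<Tensor> graphOp)
    {
        if (ctx.executing_eagerly())
        {
            try { return fastPath(); } catch { /* try the next strategy */ }
            try { return fallback(); } catch { /* try the next strategy */ }
        }
        // Not eager, or both eager strategies failed: build a graph op instead.
        return graphOp();
    }
}
```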
diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
index 66daa5c0..a01efc52 100644
--- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs
+++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
target_t.HandleData = handle_data;
return;
}
- // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
- //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
+ Status status = new();
+ var proto = handle_data.ToByteArray();
+ c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
+ status.Check(true);
}
public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
index 3e39338b..c06e822d 100644
--- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;
using Tensorflow.Eager;
+using Tensorflow.Graphs;
namespace Tensorflow
{
@@ -302,5 +303,18 @@ namespace Tensorflow
// return handle_data_util.get_resource_handle_data(handle);
//}
}
+
+ public static void variable_accessed(IVariableV1 variable)
+ {
+ if (ops.get_default_graph() is FuncGraph func_graph)
+ {
+ func_graph.watch_variable(variable);
+ }
+ if (variable.Trainable)
+ {
+ foreach (var tape in tf.GetTapeSet())
+ tape.VariableAccessed(variable);
+ }
+ }
}
}
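Note: `variable_accessed` is what lets a `GradientTape` discover trainable variables implicitly; it is also why the explicit `tape.watch(variable.Handle)` loop is deleted from Model.Train.cs further down. A hedged usage sketch (exact overloads assumed):

```csharp
// Sketch: reading a trainable variable inside a tape now reports the access
// to every active tape via variable_accessed, so no explicit watch is needed.
var w = tf.Variable(1.0f);
var x = tf.constant(2.0f);
using var tape = tf.GradientTape();
var y = tf.multiply(w.AsTensor(), x); // the read of `w` registers it on the tape
var grad = tape.gradient(y, w);       // dy/dw == 2, without tape.watch(w.Handle)
```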
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 4898cca0..935e5545 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io
-
+
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index fff3cde5..498ffda7 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -30,7 +30,7 @@ namespace Tensorflow
{
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
- public Tensor()
+ protected Tensor()
{
}
@@ -108,6 +108,7 @@ namespace Tensorflow
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);
+ _id = ops.uid();
}
protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else
_handle = TF_NewTensor(bytes, shape, dtype);
+ _id = ops.uid();
}
protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
string[] val => StringTensor(val, shape),
_ => throw new NotImplementedException("")
};
+
+ _id = ops.uid();
}
unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged
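Note: the added `_id = ops.uid()` calls guarantee every tensor receives a process-unique id at construction; the tape machinery earlier in this diff keys watched tensors by these ids (`MakeTensorIDList`). The generator only needs to be atomic and monotonic, e.g. this minimal sketch in the spirit of `ops.uid()`:

```csharp
using System.Threading;

static class UidSource
{
    static long _next = 0;

    // Atomically hands out strictly increasing ids, safe across threads.
    public static long Next() => Interlocked.Increment(ref _next);
}
```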
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
index 69dd2c10..d6986af3 100644
--- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
IEnumerable<ConcreteFunction> _concrete_functions;
FunctionSpec _function_spec;
+ public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
public RestoredFunction(Func<Tensors, Tensors> function, string name, FunctionSpec function_spec,
IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
{
diff --git a/src/TensorFlowNET.Core/Util/UnorderedMap.cs b/src/TensorFlowNET.Core/Util/UnorderedMap.cs
index fa2b91fe..219a3c14 100644
--- a/src/TensorFlowNET.Core/Util/UnorderedMap.cs
+++ b/src/TensorFlowNET.Core/Util/UnorderedMap.cs
@@ -25,6 +25,19 @@ namespace Tensorflow.Util
}
}
+ public Tv SetDefault(Tk key, Tv default_value)
+ {
+ if(TryGetValue(key, out var res))
+ {
+ return res;
+ }
+ else
+ {
+ base[key] = default_value;
+ return base[key];
+ }
+ }
+
public void push_back(Tk key, Tv value)
=> this[key] = value;
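Note: `SetDefault` mirrors Python's `dict.setdefault`: return the existing value if the key is present, otherwise insert the supplied default and return it. Usage sketch:

```csharp
using System.Collections.Generic;
using Tensorflow.Util;

var map = new UnorderedMap<string, List<int>>();
map.SetDefault("a", new List<int>()).Add(1);     // key absent: inserts, then mutates
var bucket = map.SetDefault("a", new List<int>());
// bucket.Count == 1 -- the existing list was returned, not replaced
```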
diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
index faaa0274..74ce4e8a 100644
--- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
@@ -9,6 +9,7 @@ using System.Diagnostics;
using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;
using OneOf;
+using Tensorflow.Graphs;
namespace Tensorflow
{
@@ -193,6 +194,10 @@ namespace Tensorflow
///
void variable_accessed(BaseResourceVariable variable)
{
+ if(ops.get_default_graph() is FuncGraph func_graph)
+ {
+ func_graph.watch_variable(variable as IVariableV1);
+ }
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 7aadb206..c261f3ce 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -575,12 +575,8 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
- throw new NotImplementedException();
- // This implementation hasn't been checked for some reasons.
- // If it throws an exception in the future, please check it.
-
- //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
- //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
+ var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
+ return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
}
public static void dismantle_graph(Graph graph)
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 0a06df2c..79c955b6 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train;
using Tensorflow.Training;
+using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights.
/// </summary>
protected bool built;
- public bool Built => built;
+ public bool Built
+ {
+ get
+ {
+ return built;
+ }
+ internal set
+ {
+ built = value;
+ }
+ }
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
}
protected List<ILayer> _self_tracked_trackables;
+ ///
+ /// If this value is set, the behavior of layer call will be changed to directly calling this function.
+ ///
+ public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;
+
public Layer(LayerArgs args)
{
Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
///
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
+ if(ReplacedCall is not null)
+ {
+ return ReplacedCall(inputs);
+ }
return inputs;
}
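Note: `ReplacedCall` gives SavedModel loading a hook to redirect a layer's forward pass without subclassing; `KerasObjectLoader` (below) wires it to the restored `call_and_return_conditional_losses` function. Hedged sketch (`restoredFunction` stands in for the `RestoredFunction` recovered from the SavedModel):

```csharp
// Any subsequent layer call now routes through the restored concrete function.
layer.ReplacedCall = inputs => restoredFunction.Apply(inputs);
```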
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index 5cf34250..905ea453 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
- //foreach (var variable in TrainableVariables)
- //{
- // tape.watch(variable.Handle);
- //}
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred);
@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
gradients = optimizer.clip_gradients(gradients);
- optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
+ optimizer.apply_gradients(zip(gradients, trainable_variables),
experimental_aggregate_gradients: false);
}
}
diff --git a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
index dcd7535f..e49d757a 100644
--- a/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
+++ b/src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
_set_hyper("decay", args.InitialDecay);
}
- public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+ public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
=> apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
/// <param name="grads_and_vars"></param>
/// <param name="name"></param>
/// <param name="experimental_aggregate_gradients"></param>
- public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+ public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
{
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
});
}
- void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+ void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
_resource_apply_dense(var, grad, apply_state);
// if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
throw new NotImplementedException("_resource_apply_dense");
}
- void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+ void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name,
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{
diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
index aed6769a..9cdd3b50 100644
--- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
+++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
///
private void _finalize_saved_model_layers(List<Layer> layers)
{
+ foreach(var layer in layers)
+ {
+ layer.Built = true;
+ var keras_attr = _get_keras_attr(layer);
+ if(keras_attr is not Trackable trackable)
+ {
+ continue;
+ }
+ if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
+ {
+ Debug.Assert(layer_call is RestoredFunction);
+ var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions;
+ if (concrete_functions is not null && concrete_functions.Count() > 0)
+ {
+ layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply);
+ }
+ }
+ }
+
foreach(var layer in layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
}
}
+ private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
+ {
+ // TODO(Rinne): revise it.
+ return call;
+ }
+
private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.
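Note: `use_wrapped_call` is an identity passthrough for now (see the TODO above); the Python original wraps the restored call so outputs and conditional losses are separated. A hedged sketch of the intended shape — an assumption per the TODO, not this diff's behavior:

```csharp
using System;
using Tensorflow;

static class WrappedCallSketch
{
    public static Func<Tensors, Tensors> Wrap(Func<Tensors, Tensors> call)
    {
        return inputs =>
        {
            var outputs = call(inputs);
            // A complete port would also collect and register the conditional
            // losses returned by call_and_return_conditional_losses here.
            return outputs;
        };
    }
}
```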
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
index 4df6613f..bca84a86 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
return _config;
}
- protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
- {
- if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
- {
- return base.Call(inputs, state, training);
- }
- else
- {
- return (func as Function).Apply(inputs);
- }
- }
+ //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ //{
+ // if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
+ // {
+ // return base.Call(inputs, state, training);
+ // }
+ // else
+ // {
+ // return (func as Function).Apply(inputs);
+ // }
+ //}
}
}
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
index 9d611efe..0ec5d1a8 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
//base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
// functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}),
- functions.Concat(new string[] { }))
+ functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
{
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
index cb230605..51962830 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
@@ -64,23 +64,19 @@ public class SequentialModelLoad
var model = tf.keras.models.load_model(@"Assets/python_func_model");
model.summary();
- var x = tf.random.uniform((8, 784), -1, 1);
- var y = model.Apply(x);
- Console.WriteLine(y);
+ model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
- //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-
- //var data_loader = new MnistModelLoader();
- //var num_epochs = 1;
- //var batch_size = 8;
+ var data_loader = new MnistModelLoader();
+ var num_epochs = 1;
+ var batch_size = 8;
- //var dataset = data_loader.LoadAsync(new ModelLoadSetting
- //{
- // TrainDir = "mnist",
- // OneHot = false,
- // ValidationSize = 58000,
- //}).Result;
+ var dataset = data_loader.LoadAsync(new ModelLoadSetting
+ {
+ TrainDir = "mnist",
+ OneHot = false,
+ ValidationSize = 55000,
+ }).Result;
- //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
+ model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}