| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Google.Protobuf; | |||
| using Protobuf.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Contexts | |||
| @@ -12,18 +12,36 @@ namespace Tensorflow.Eager | |||
| return HasGradientTape(); | |||
| } | |||
| private bool ShouldRecord(Tensor[] inputs) | |||
| public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) | |||
| { | |||
| bool should_record = false; | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| var tape_set = tf.GetTapeSet(); | |||
| var input_ids = MakeTensorIDList(tensors); | |||
| var input_dtypes = MakeTensorDtypeList(tensors); | |||
| bool some_tape_watching = false; | |||
| if (tape_set is not null && tape_set.Count > 0) | |||
| { | |||
| if (tape.ShouldRecord(inputs)) | |||
| foreach (var tape in tape_set) | |||
| { | |||
| should_record = true; | |||
| break; | |||
| if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
| { | |||
| if (tape.Persistent || some_tape_watching) | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
| } | |||
| some_tape_watching = true; | |||
| } | |||
| } | |||
| } | |||
| return should_record; | |||
| // skip the forward_accumulators. | |||
| if (some_tape_watching) | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
| } | |||
| else | |||
| { | |||
| return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; | |||
| } | |||
| } | |||
| } | |||
| } | |||
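| // Editorial usage sketch, not part of the diff: the returned int is one of the | |||
| // gradients_util.POSSIBLE_GRADIENT_TYPES_* constants; `x` and `y` are assumed tensors. | |||
| int grad_type = tf.Runner.TFE_TapeSetPossibleGradientTypes(new[] { x, y }); | |||
| if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE) | |||
| { | |||
| // No active tape is watching these tensors; gradient bookkeeping can be skipped. | |||
| } | |||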
| @@ -13,7 +13,17 @@ namespace Tensorflow.Eager | |||
| Tensor[] results, | |||
| BackwardFunction backwardFunction = null) | |||
| { | |||
| bool should_record = ShouldRecord(inputs); | |||
| var input_ids = MakeTensorIDList(inputs); | |||
| var input_dtypes = MakeTensorDtypeList(inputs); | |||
| bool should_record = false; | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| { | |||
| if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
| { | |||
| should_record = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!should_record) | |||
| { | |||
| @@ -59,7 +69,7 @@ namespace Tensorflow.Eager | |||
| op_inputs = inputs;*/ | |||
| backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | |||
| TapeSetRecordOperation(op_name, inputs, results, backwardFunction); | |||
| TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); | |||
| return true; | |||
| } | |||
| @@ -129,10 +139,5 @@ namespace Tensorflow.Eager | |||
| { | |||
| return HasGradientTape(); | |||
| } | |||
| TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
| { | |||
| return tensors.Select(x => x.dtype).ToArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,8 @@ | |||
| using System; | |||
| using OneOf.Types; | |||
| using System; | |||
| using Tensorflow.Gradients; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| @@ -9,40 +11,183 @@ namespace Tensorflow.Eager | |||
| /// </summary> | |||
| public partial class EagerRunner | |||
| { | |||
| /// <summary> | |||
| /// Computes the gradients of <paramref name="target"/> with respect to | |||
| /// <paramref name="sources"/> using the operations recorded on <paramref name="tape"/>. | |||
| /// </summary> | |||
| /// <param name="tape"></param> | |||
| /// <param name="target"></param> | |||
| /// <param name="sources"></param> | |||
| /// <param name="output_gradients"></param> | |||
| /// <param name="unconnected_gradients">determines the value returned if the target and | |||
| /// sources are unconnected. When 'none' the value returned is null, whereas when | |||
| /// 'zero' a zero tensor in the same shape as the sources is returned.</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public Tensor[] TFE_TapeGradient(ITape tape, | |||
| Tensor[] target, | |||
| Tensor[] sources, | |||
| Tensor[] output_gradients) | |||
| List<Tensor> output_gradients, | |||
| Tensor[] sources_raw, | |||
| string unconnected_gradients) | |||
| { | |||
| var target_vec = target; | |||
| var sources_vec = sources; | |||
| var sources_set = sources_vec; | |||
| if (!tape.Persistent) | |||
| { | |||
| var tape_set = tf.GetTapeSet(); | |||
| if (tape_set.Contains(tape)) | |||
| { | |||
| throw new RuntimeError("gradient() cannot be invoked within the " + | |||
| "GradientTape context (i.e., while operations are being " + | |||
| "recorded). Either move the call to gradient() to be " + | |||
| "outside the 'with tf.GradientTape' block, or " + | |||
| "use a persistent tape: " + | |||
| "'with tf.GradientTape(persistent=true)'"); | |||
| } | |||
| } | |||
| var target_vec = MakeTensorIDList(target); | |||
| var sources_vec = MakeTensorIDList(sources); | |||
| HashSet<long> sources_set = new HashSet<long>(sources_vec); | |||
| var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); | |||
| int len = target.Length; | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| var target_id = target_vec[i]; | |||
| if (sources_set.Contains(target_id)) | |||
| { | |||
| var tensor = target[i]; | |||
| source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); | |||
| } | |||
| } | |||
| List<Tensor> outgrad_vec = new(); | |||
| if(output_gradients is not null) | |||
| { | |||
| outgrad_vec = output_gradients.ToList(); | |||
| } | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
| var seq_array = target; | |||
| var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>(); | |||
| for (int i = 0; i < target.Length; ++i) | |||
| bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
| Tensor[] sources_obj = null; | |||
| if (unconnected_gradients_zero) | |||
| { | |||
| source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i])); | |||
| sources_obj = MakeTensorList(sources_raw); | |||
| } | |||
| if (output_gradients != null) | |||
| if (result.Length > 0) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| for(int i = 0; i < result.Length; i++) | |||
| { | |||
| if (result[i] is null && unconnected_gradients_zero) | |||
| { | |||
| var dtype = sources_obj[i].dtype; | |||
| result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| return result; | |||
| } | |||
| Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) | |||
| { | |||
| return tensors.ToArray(); | |||
| } | |||
| long[] MakeTensorIDList(Tensor[] tensors) | |||
| { | |||
| int len = tensors.Length; | |||
| long[] ids = new long[len]; | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| var tensor = tensors[i]; | |||
| ids[i] = tensor.Id; | |||
| } | |||
| return ids; | |||
| } | |||
| TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
| { | |||
| int len = tensors.Length; | |||
| TF_DataType[] dtypes = new TF_DataType[len]; | |||
| for (int i = 0; i < len; i++) | |||
| { | |||
| output_gradients = new Tensor[0]; | |||
| var tensor = tensors[i]; | |||
| dtypes[i] = tensor.dtype; | |||
| } | |||
| return dtypes; | |||
| } | |||
| var outgrad_vec = MakeTensorList(output_gradients); | |||
| TapeTensor TapeTensorFromTensor(Tensor tensor) | |||
| { | |||
| long id = tensor.Id; | |||
| var dtype = tensor.dtype; | |||
| if (tensor is EagerTensor) | |||
| { | |||
| var handle = tensor.EagerTensorHandle; | |||
| if (DTypeNeedsHandleData(dtype)) | |||
| { | |||
| return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); | |||
| } | |||
| Status status = new(); | |||
| int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); | |||
| long[] dims = new long[num_dims]; | |||
| for(int i = 0; i < num_dims; i++) | |||
| { | |||
| dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
| } | |||
| Shape tensor_shape = new(dims); | |||
| if(status.Code != TF_Code.TF_OK) | |||
| { | |||
| return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); | |||
| } | |||
| else | |||
| { | |||
| return new TapeTensor(id, dtype, tensor_shape); | |||
| } | |||
| } | |||
| var shape_tuple = tensor.shape.dims; | |||
| if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) | |||
| { | |||
| return new TapeTensor(id, dtype, tensor); | |||
| } | |||
| long[] l = new long[shape_tuple.Length]; | |||
| for(int i = 0; i < shape_tuple.Length; i++) | |||
| { | |||
| if (shape_tuple[i] < 0) | |||
| { | |||
| l[i] = 0; | |||
| } | |||
| else | |||
| { | |||
| l[i] = shape_tuple[i]; | |||
| } | |||
| } | |||
| return new TapeTensor(id, dtype, new Shape(l)); | |||
| } | |||
| return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec); | |||
| bool DTypeNeedsHandleData(TF_DataType dtype) | |||
| { | |||
| return dtype == dtypes.variant || dtype == dtypes.resource; | |||
| } | |||
| Tensor[] MakeTensorList(Tensor[] tensors) | |||
| bool ListContainNone(long[] list) | |||
| { | |||
| return tensors; | |||
| int len = list.Length; | |||
| if(len == 0) | |||
| { | |||
| return true; | |||
| } | |||
| for(int i = 0; i < len; i++) | |||
| { | |||
| if (list[i] == -1) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
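| // Editorial sketch of the new signature, mirroring how GradientTape.gradient (later in | |||
| // this diff) calls it. `tape`, `y` and `x` are assumed to exist; with | |||
| // unconnected_gradients set to "zero", an unconnected source yields a zeros tensor | |||
| // instead of null. | |||
| var grads = tf.Runner.TFE_TapeGradient(tape, | |||
| new[] { y }, // targets | |||
| new[] { x }, // sources | |||
| null, // output_gradients | |||
| new[] { x }, // sources_raw, used to build zeros for unconnected sources | |||
| "zero"); | |||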
| @@ -7,8 +7,9 @@ namespace Tensorflow.Eager | |||
| public partial class EagerRunner | |||
| { | |||
| void TapeSetRecordBackprop(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| TapeTensor[] output_info, | |||
| long[] input_ids, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function) | |||
| { | |||
| if (!CouldBackprop()) | |||
| @@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| { | |||
| tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function); | |||
| tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -10,18 +10,28 @@ namespace Tensorflow.Eager | |||
| public bool TapeSetRecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| Tensor[] output_tensors, | |||
| long[] input_ids, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function) | |||
| { | |||
| var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray(); | |||
| var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); | |||
| if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | |||
| backward_function)) | |||
| return false; | |||
| TapeSetRecordBackprop(op_type, input_tensors, output_info, | |||
| TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, | |||
| backward_function); | |||
| return true; | |||
| } | |||
| public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
| Tensor[] input_tensors, BackwardFunction backward_function) | |||
| { | |||
| var input_ids = MakeTensorIDList(input_tensors); | |||
| var input_dtypes = MakeTensorDtypeList(input_tensors); | |||
| TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
| backward_function); | |||
| } | |||
| } | |||
| } | |||
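| // Editorial sketch: the new convenience overload derives ids and dtypes itself, so a | |||
| // caller such as TapeGradientFunctions.Record (later in this diff) only passes tensors. | |||
| // `op_outputs`, `op_inputs`, `backward_fn` and the op name are assumed examples. | |||
| tf.Runner.TFE_TapeSetRecordOperation("MyOp", op_outputs, op_inputs, backward_fn); | |||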
| @@ -29,7 +29,14 @@ namespace Tensorflow.Eager | |||
| Tensor[] TFE_TapeGradient(ITape tape, | |||
| Tensor[] target, | |||
| Tensor[] sources, | |||
| Tensor[] output_gradients); | |||
| List<Tensor> output_gradients, | |||
| Tensor[] sources_raw, | |||
| string unconnected_gradients); | |||
| void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
| Tensor[] input_tensors, BackwardFunction backward_function); | |||
| int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); | |||
| bool RecordGradient(string op_name, | |||
| Tensor[] inputs, | |||
| @@ -18,12 +18,13 @@ namespace Tensorflow.Functions | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| protected IEnumerable<Tensor> _captured_inputs; | |||
| internal FuncGraph func_graph; | |||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
| protected Dictionary<string, AttrValue> _attrs; | |||
| protected FunctionSpec _function_spec; | |||
| protected FunctionSpec _pre_initialized_function_spec = null; | |||
| protected EagerDefinedFunction _inference_function; | |||
| protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); | |||
| internal FuncGraph func_graph; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| @@ -156,6 +157,17 @@ namespace Tensorflow.Functions | |||
| { | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| var default_graph = ops.get_default_graph(); | |||
| // TODO(Rinne): deal with `default_graph.building_function` | |||
| var tempvv = func_graph.Variables; | |||
| if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) | |||
| { | |||
| foreach(var v in this.func_graph.Variables) | |||
| { | |||
| resource_variable_ops.variable_accessed(v); | |||
| } | |||
| } | |||
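| // Editorial note: variable_accessed marks each captured variable as read so that any | |||
| // active tape watches its handle; without this, gradients with respect to variables | |||
| // captured by the function graph would be silently dropped. | |||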
| var tensor_inputs = new Tensors(); | |||
| foreach (var (i, arg) in enumerate(args)) | |||
| { | |||
| @@ -223,11 +235,16 @@ namespace Tensorflow.Functions | |||
| { | |||
| input_tangents = new TangentInfo(); | |||
| } | |||
| if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient()) | |||
| if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||
| { | |||
| if(input_tangents.Indices is not null || executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| string cache_key = "first_order"; | |||
| if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) | |||
| { | |||
| functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| _tape_functions_cache[cache_key] = functions; | |||
| } | |||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
| } | |||
| else | |||
| @@ -241,7 +258,7 @@ namespace Tensorflow.Functions | |||
| } | |||
| // TODO(Rinne): add arg "input_tangents" for ForwardBackwardCall. | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient()); | |||
| return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
| } | |||
| internal void set_variables(IEnumerable<IVariableV1> variables) | |||
| @@ -124,17 +124,16 @@ namespace Tensorflow.Functions | |||
| // TODO(Rinne): Add arg `CancellationManager`. | |||
| // TODO(Rinne): Check the arg length. | |||
| var function_call_options = tf.Context.FunctionCallOptions; | |||
| string config; | |||
| if (function_call_options.config_proto_serialized().Length == 0) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config().ToString(); | |||
| } | |||
| else | |||
| { | |||
| config = function_call_options.config_proto_serialized().ToString(); | |||
| } | |||
| string config = ""; // TODO(Rinne): revise this. The commented-out code below should work but does not, for unclear reasons. | |||
| //if (function_call_options.config_proto_serialized().Length == 0) | |||
| //{ | |||
| // config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
| //} | |||
| //else | |||
| //{ | |||
| // config = function_call_options.config_proto_serialized().ToStringUtf8(); | |||
| //} | |||
| string executor_type = function_call_options.ExecutorType ?? ""; | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| @@ -14,12 +14,11 @@ namespace Tensorflow.Functions | |||
| } | |||
| public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
| public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| ForwardAndBackwardFunctions(Tensors inference_args) | |||
| { | |||
| var outputs = _func_graph.Outputs; | |||
| (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| = BuildFunctionsForOutputs(outputs, inference_args); | |||
| return _forward_function; | |||
| var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); | |||
| return BuildFunctionsForOutputs(outputs, inference_args); | |||
| } | |||
| } | |||
| } | |||
| @@ -14,7 +14,6 @@ namespace Tensorflow | |||
| protected ConcreteFunction _concrete_variable_creation_fn; | |||
| protected bool _autograph; | |||
| protected TracingCompiler _variable_creation_fn; | |||
| protected bool _has_initialized; | |||
| public string Name { get; set; } | |||
| public Function(Func<Tensor[], Tensor[]> csharp_function, | |||
| string name, bool auto_graph = true) | |||
| @@ -22,7 +21,6 @@ namespace Tensorflow | |||
| _csharp_function = csharp_function; | |||
| Name = name; | |||
| _autograph = auto_graph; | |||
| _has_initialized = false; | |||
| } | |||
| public virtual Tensors Apply(Tensors inputs) | |||
| @@ -38,10 +36,11 @@ namespace Tensorflow | |||
| protected virtual Tensors _call(Tensors inputs) | |||
| { | |||
| if (!_has_initialized) | |||
| if(_variable_creation_fn is not null) | |||
| { | |||
| _initialize(inputs); | |||
| return _variable_creation_fn.Apply(inputs); | |||
| } | |||
| _initialize(inputs); | |||
| return _concrete_variable_creation_fn.CallFlat(inputs, | |||
| _concrete_variable_creation_fn.CapturedInputs); | |||
| @@ -63,7 +62,6 @@ namespace Tensorflow | |||
| _variable_creation_fn = _compiler(_csharp_function); | |||
| _variable_creation_fn._name = this.Name; | |||
| _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
| _has_initialized = true; | |||
| } | |||
| } | |||
| } | |||
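| // Editorial sketch of the simplified call path: the first Apply runs _initialize and | |||
| // CallFlat; later calls take the _variable_creation_fn branch once it is non-null. | |||
| // Constructor signature as declared above; the lambda body is an assumed example. | |||
| var square = new Function(xs => new[] { xs[0] * xs[0] }, name: "square"); | |||
| var y = square.Apply(new Tensors(tf.constant(3.0f))); // traces on first call | |||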
| @@ -24,23 +24,40 @@ namespace Tensorflow.Functions | |||
| protected string _INFERENCE_PREFIX = "__inference_"; | |||
| protected FuncGraph _func_graph; | |||
| protected EagerDefinedFunction _forward_function; | |||
| protected EagerDefinedFunction _forward; | |||
| protected FuncGraph _forward_graph; | |||
| protected List<int> _forwardprop_input_indices; | |||
| protected List<int> _forwardprop_output_indices; | |||
| protected int _num_forwardprop_outputs; | |||
| protected ConcreteFunction _backward_function; | |||
| protected int _num_inference_outputs; | |||
| protected int _num_outputs; | |||
| protected int _num_trainable_inference_outputs; | |||
| protected ConcreteFunction _backward; | |||
| BackwardFunction _backward_function_wrapper; | |||
| public TapeGradientFunctions(FuncGraph func_graph, | |||
| bool need_gradients_for_jvps) | |||
| { | |||
| _func_graph = func_graph; | |||
| _forward_graph = null; | |||
| _forward = null; | |||
| _backward = null; | |||
| _num_outputs = func_graph.Outputs.Length; | |||
| _forwardprop_output_indices = null; | |||
| _num_forwardprop_outputs = 0; | |||
| _num_inference_outputs = func_graph.Outputs.Length; | |||
| _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); | |||
| } | |||
| public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||
| { | |||
| // TODO(Rinne): add input_tangents arg. | |||
| return ForwardAndBackwardFunctions(inference_args); | |||
| if(_forward is null) | |||
| { | |||
| (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
| = ForwardAndBackwardFunctions(inference_args); | |||
| } | |||
| return _forward; | |||
| } | |||
| /// <summary> | |||
| @@ -51,9 +68,13 @@ namespace Tensorflow.Functions | |||
| public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
| { | |||
| // TODO(Rinne): add arg `input_tangents`. | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs); | |||
| tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record, | |||
| getBackwardFunction: backward_function); | |||
| var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
| if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); | |||
| } | |||
| /// <summary> | |||
| @@ -65,66 +86,95 @@ namespace Tensorflow.Functions | |||
| /// <returns></returns> | |||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
| { | |||
| var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) | |||
| .ToDictionary(x => x.Item1, x => x.Item2); | |||
| var captured_inputs = backward.CapturedInputs; | |||
| var remapped_captures = captured_inputs.Select(c => | |||
| { | |||
| if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) | |||
| { | |||
| return value; | |||
| } | |||
| else | |||
| { | |||
| return c; | |||
| } | |||
| }).ToArray(); | |||
| if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) | |||
| { | |||
| var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph); | |||
| throw new RuntimeError($"Failed to map all backward graph captures to " + | |||
| $"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); | |||
| } | |||
| Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>(); | |||
| var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
| var recorded_outputs = new Tensors(); | |||
| var trainable_recorded_outputs = 0; | |||
| foreach (var (output_index, output) in enumerate(outputs)) | |||
| int trainable_recorded_outputs = 0; | |||
| var skip_positions = new HashSet<int>(); | |||
| var relevant_outputs = outputs; | |||
| foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
| { | |||
| if (trainable_recorded_outputs < backward_function_inputs) | |||
| recorded_outputs.Add(output); | |||
| if (gradients_util.IsTrainable(output)) | |||
| trainable_recorded_outputs += 1; | |||
| if (backprop_util.IsTrainable(output)) | |||
| trainable_recorded_outputs++; | |||
| else | |||
| skip_positions.Add(output_index); | |||
| if (output.dtype == dtypes.variant) | |||
| variant_zeros_like[output_index] = default_gradient.zeros_like(output); | |||
| } | |||
| if(_backward_function_wrapper == null) | |||
| _backward_function_wrapper = (args, unneeded_gradients) => | |||
| { | |||
| var capture_mapping = new Dictionary<long, Tensor>(); | |||
| foreach (var (i, output) in enumerate(outputs)) | |||
| capture_mapping[forward_graph.Outputs[i].Id] = output; | |||
| var remapped_captures = new Tensors(); | |||
| foreach (var capture in backward.CapturedInputs) | |||
| { | |||
| if (capture_mapping.ContainsKey(capture.Id)) | |||
| remapped_captures.Add(capture_mapping[capture.Id]); | |||
| } | |||
| var skip_positions = new List<int>(); | |||
| foreach (var (output_index, output) in enumerate(outputs)) | |||
| if(backward.Outputs is null || backward.Outputs.Length == 0) | |||
| { | |||
| if (!gradients_util.IsTrainable(output)) | |||
| skip_positions.Add(output_index); | |||
| return backward.FlatStructuredOutputs; | |||
| } | |||
| _backward_function_wrapper = (args, unneeded_gradients) => | |||
| var processed_args = new Tensors(); | |||
| int input_index = 0; | |||
| foreach (var (output_index, arg) in enumerate(args)) | |||
| { | |||
| var processed_args = new Tensors(); | |||
| var input_index = 0; | |||
| foreach (var (output_index, arg) in enumerate(args)) | |||
| if (skip_positions.Contains(output_index)) | |||
| continue; | |||
| if (arg is null) | |||
| { | |||
| var input_placeholder = backward.Inputs[input_index]; | |||
| Tensor variant_arg; | |||
| if (input_placeholder.dtype == dtypes.variant) | |||
| { | |||
| variant_arg = variant_zeros_like[output_index]; | |||
| } | |||
| else | |||
| { | |||
| var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); | |||
| variant_arg = array_ops.zeros(shape, type); | |||
| } | |||
| processed_args.Add(variant_arg); | |||
| } | |||
| else | |||
| { | |||
| if (skip_positions.Contains(output_index)) | |||
| continue; | |||
| if (arg == null) | |||
| throw new NotImplementedException(""); | |||
| processed_args.Add(arg); | |||
| input_index += 1; | |||
| if (input_index >= backward_function_inputs) | |||
| break; | |||
| } | |||
| input_index++; | |||
| if (input_index >= backward_function_inputs) | |||
| break; | |||
| } | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||
| { | |||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||
| if (gradients.Length <= index) | |||
| gradients.Insert(index, null); | |||
| } | |||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||
| { | |||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||
| if (gradients.Length <= index) | |||
| gradients.Insert(index, null); | |||
| } | |||
| return gradients; | |||
| }; | |||
| } | |||
| return gradients; | |||
| }; | |||
| return (_backward_function_wrapper, recorded_outputs); | |||
| } | |||
| @@ -143,7 +193,7 @@ namespace Tensorflow.Functions | |||
| } | |||
| } | |||
| var backwards_graph = new FuncGraph(_func_graph.Name); | |||
| var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
| backwards_graph.as_default(); | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| foreach (var output in trainable_outputs) | |||
| @@ -153,6 +203,7 @@ namespace Tensorflow.Functions | |||
| gradients_wrt_outputs.Add(gradient_placeholder); | |||
| handle_data_util.copy_handle_data(output, gradient_placeholder); | |||
| } | |||
| // TODO(Rinne): with ops.device(None) | |||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
| _func_graph.Inputs, | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| @@ -175,7 +226,8 @@ namespace Tensorflow.Functions | |||
| backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||
| backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||
| var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
| var (wrapped_forward_function, wrapped_backward_function) = | |||
| monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
| //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
| //var backward_function_attr = new Dictionary<string, string>(); | |||
| //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| @@ -189,10 +241,11 @@ namespace Tensorflow.Functions | |||
| // _func_graph.Inputs, _func_graph.Outputs, | |||
| // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||
| return (forward_function, _func_graph, backward_function, null, 0); | |||
| return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); | |||
| } | |||
| public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||
| public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| ForwardAndBackwardFunctions(Tensors inference_args) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -73,12 +73,12 @@ namespace Tensorflow.Functions | |||
| private static string make_cache_key(Tensor[] inputs) | |||
| { | |||
| string res = ""; | |||
| foreach (var input in inputs) | |||
| { | |||
| res += $"{input.name}_{input.Id}"; | |||
| } | |||
| return res; | |||
| //string res = ""; | |||
| //foreach (var input in inputs) | |||
| //{ | |||
| // res += $"{input.name}_{input.Id}"; | |||
| //} | |||
| return inputs.Length.ToString(); | |||
| } | |||
| } | |||
| } | |||
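| // Editorial note on the interim cache key: keying by inputs.Length alone means every | |||
| // call with the same arity reuses one traced function, ignoring shapes and dtypes; | |||
| // this trades retracing correctness for avoiding the broken name/id key commented above. | |||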
| @@ -153,7 +153,7 @@ namespace Tensorflow.Functions | |||
| foreach(var tape in tf.GetTapeSet()) | |||
| { | |||
| tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||
| inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function); | |||
| inference_args, backward_function); | |||
| } | |||
| } | |||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Gradients | |||
| /// Map from tensor to how many references still exist for this tensor in | |||
| /// the tape. | |||
| /// </summary> | |||
| public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; } | |||
| public UnorderedMap<long, long> tensor_usage_counts { get; set; } | |||
| /// <summary> | |||
| /// Maps from op ID to how many output tensors of this op still need to have | |||
| /// their gradients computed. | |||
| @@ -19,7 +19,7 @@ namespace Tensorflow.Gradients | |||
| public BackpropInitialState() | |||
| { | |||
| op_tape = new OpTape(); | |||
| tensor_usage_counts = new UnorderedMap<Tensor, long>(); | |||
| tensor_usage_counts = new UnorderedMap<long, long>(); | |||
| op_missing_tensor = new UnorderedMap<long, long>(); | |||
| } | |||
| } | |||
| @@ -67,40 +67,59 @@ namespace Tensorflow.Gradients | |||
| /// <param name="target"></param> | |||
| /// <param name="source"></param> | |||
| /// <returns></returns> | |||
| public Tensor gradient(Tensor target, Tensor source) | |||
| public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| if(_tape is null) | |||
| { | |||
| throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
| "compute one set of gradients (or jacobians)."); | |||
| } | |||
| ITape tape = stop_recording(); | |||
| var results = tf.Runner.TFE_TapeGradient(tape, | |||
| new[] { target }, | |||
| new[] { source }, | |||
| null); | |||
| output_gradients, | |||
| new[] { source }, | |||
| unconnected_gradients); | |||
| return results[0]; | |||
| } | |||
| public Tensor gradient(Tensor target, ResourceVariable source) | |||
| public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| var results = gradient(target, new List<IVariableV1> { source }); | |||
| var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients); | |||
| return results[0]; | |||
| } | |||
| public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) | |||
| public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }); | |||
| var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); | |||
| return (results[0], results[1]); | |||
| } | |||
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources) | |||
| public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null, | |||
| string unconnected_gradients = null) | |||
| { | |||
| if (_tape is null) | |||
| { | |||
| throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
| "compute one set of gradients (or jacobians)."); | |||
| } | |||
| var tape = stop_recording(); | |||
| var results = tf.Runner.TFE_TapeGradient(tape, | |||
| new[] { target }, | |||
| sources.Select(x => x.Handle).ToArray(), | |||
| null); | |||
| output_gradients, | |||
| sources.Select(x => x.Handle).ToArray(), | |||
| unconnected_gradients); | |||
| if (!tape.Persistent) | |||
| { | |||
| @@ -6,24 +6,31 @@ namespace Tensorflow.Gradients | |||
| public interface ITape | |||
| { | |||
| void SetTapeId(int id); | |||
| bool ShouldRecord(Tensor[] tensors); | |||
| bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); | |||
| void StartRecord(); | |||
| void StopRecord(); | |||
| bool Persistent { get; } | |||
| void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function); | |||
| void VariableAccessed(ResourceVariable variable); | |||
| void RecordOperation(string op_type, | |||
| Tensor[] outputs, | |||
| Tensor[] inputs, | |||
| BackwardFunction backward_function); | |||
| void VariableAccessed(IVariableV1 variable); | |||
| void Watch(Tensor x); | |||
| ResourceVariable[] WatchedVariables(); | |||
| IVariableV1[] WatchedVariables(); | |||
| Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
| Tensor[] source_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients); | |||
| Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
| long[] source_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| bool build_default_zeros_grads); | |||
| } | |||
| } | |||
| @@ -9,9 +9,9 @@ namespace Tensorflow.Gradients | |||
| { | |||
| public string op_type { get; set; } | |||
| public TapeTensor[] output_tensor_info { get; set; } | |||
| public Tensor[] input_tensor_id { get; set; } | |||
| public long[] input_tensor_id { get; set; } | |||
| public BackwardFunction backward_function { get; set; } | |||
| public override string ToString() | |||
| => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}"; | |||
| => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; | |||
| } | |||
| } | |||
| @@ -2,235 +2,246 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public partial class Tape | |||
| { | |||
| // int kMinAggregateCount = 4; | |||
| // int kMinAggregateBytes = 128 * 1024 * 1024; | |||
| static readonly int kMinAggregateCount = 4; | |||
| static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||
| private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||
| public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||
| Tensor[] source_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients) | |||
| static Tape() | |||
| { | |||
| var sources_set = new UnorderedSet<Tensor>(source_tensor_ids); | |||
| // var gradients_size = new UnorderedMap<Tensor, long>(); | |||
| var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); | |||
| var state = PrepareBackprop( | |||
| target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); | |||
| var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
| var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||
| output_gradients, | |||
| tensor_tape_, | |||
| state.op_tape); | |||
| _functionsAcceptingNoneForIndicesMap = new(); | |||
| _functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| _functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| _functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
| } | |||
| while (!op_stack.empty()) | |||
| public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
| long[] source_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| bool build_default_zeros_grads) | |||
| { | |||
| UnorderedSet<long> sources_set = new(source_tensor_ids); | |||
| BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||
| var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
| var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||
| UnorderedMap<long, long> gradients_size = new(); | |||
| while(op_stack.Count > 0) | |||
| { | |||
| var op = op_stack.Dequeue(); | |||
| if (!state.op_tape.find(op, out var trace)) | |||
| long op = op_stack.Dequeue(); | |||
| if(!state.op_tape.TryGetValue(op, out var op_it)) | |||
| { | |||
| continue; | |||
| // Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||
| } | |||
| var trace = op_it; | |||
| state.op_tape.erase(op); | |||
| var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||
| var unneeded_gradients = new List<long>(); | |||
| for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||
| List<Tensor> out_gradients = new(); | |||
| List<long> unneeded_gradients = new(); | |||
| for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||
| { | |||
| var in_tensor_id = trace.input_tensor_id[i]; | |||
| if (!tensor_tape_.find(in_tensor_id) && | |||
| !sources_set.find(in_tensor_id)) | |||
| long in_tensor_id = trace.input_tensor_id[i]; | |||
| if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||
| { | |||
| unneeded_gradients.Add(i); | |||
| } | |||
| } | |||
| bool any_gradient_nonzero = false; | |||
| var zero_indices = new List<int>(); | |||
| for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||
| List<int> zero_indices = new(); | |||
| for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||
| { | |||
| var id = trace.output_tensor_info[i].GetTensor(); | |||
| if (!gradients.find(id, out var grad_it)) | |||
| long id = trace.output_tensor_info[i].GetID(); | |||
| if(!gradients.TryGetValue(id, out var grad_it)) | |||
| { | |||
| if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && | |||
| func_name_it.find(i)) | |||
| out_gradients.Add(null); | |||
| if (build_default_zeros_grads) | |||
| { | |||
| out_gradients.Add(null); | |||
| } | |||
| else | |||
| { | |||
| out_gradients.Add(null); | |||
| zero_indices.Add(i); | |||
| if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||
| !func_name_it.find(i)) | |||
| { | |||
| zero_indices.Add(i); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| any_gradient_nonzero = true; | |||
| var new_gradients = grad_it.Count == 1 ? | |||
| grad_it[0] : | |||
| gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||
| Tensor new_gradients; | |||
| if (grad_it.Count == 1) | |||
| { | |||
| new_gradients = grad_it[0]; | |||
| } | |||
| else | |||
| { | |||
| new_gradients = AggregateGradients(grad_it); | |||
| } | |||
| if (!sources_set.find(id)) | |||
| { | |||
| gradients.Remove(id); | |||
| } | |||
| else | |||
| { | |||
| // grad_it.Clear(); | |||
| // grad_it.Add(new_gradients); | |||
| // vspace.MarkAsResult(new_gradients); | |||
| grad_it.Clear(); | |||
| grad_it.Add(new_gradients); | |||
| // MarkAsResult | |||
| } | |||
| out_gradients.Add(new_gradients); | |||
| } | |||
| } | |||
| Tensor[] in_gradients; | |||
| Tensor[] in_gradients = new Tensor[0]; | |||
| if (any_gradient_nonzero) | |||
| { | |||
| // foreach (var i in zero_indices) | |||
| // out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
| in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); | |||
| if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) | |||
| throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||
| if (!_persistent) | |||
| foreach(var i in zero_indices) | |||
| { | |||
| // trace.backward_function_deleter(trace.backward_function); | |||
| trace.backward_function = null; | |||
| out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
| } | |||
| in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||
| } | |||
| else | |||
| { | |||
| in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||
| out_gradients.Clear(); | |||
| } | |||
| bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; | |||
| for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) | |||
| for(int i = 0, end = in_gradients.Length; i < end; i++) | |||
| { | |||
| if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; | |||
| var id = trace.input_tensor_id[k]; | |||
| if (in_gradients[i] != null) | |||
| long id = trace.input_tensor_id[i]; | |||
| if (in_gradients[i] is not null) | |||
| { | |||
| var unaggregated_grads = gradients[id]; | |||
| var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||
| unaggregated_grads.Add(in_gradients[i]); | |||
| /*if (unaggregated_grads.Count > kMinAggregateCount) | |||
| if(unaggregated_grads.Count > kMinAggregateCount) | |||
| { | |||
| if (!gradients_size.find(id, out var size)) | |||
| if(!gradients_size.TryGetValue(id, out var size)) | |||
| { | |||
| size = (long)unaggregated_grads[0].size; | |||
| size = NumElements(unaggregated_grads[0]); | |||
| gradients_size.emplace(id, size); | |||
| } | |||
| if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
| if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| Tensor grad = AggregateGradients(unaggregated_grads); | |||
| unaggregated_grads.Clear(); | |||
| unaggregated_grads.Add(grad); | |||
| } | |||
| }*/ | |||
| } | |||
| } | |||
| if (!state.tensor_usage_counts.find(id)) | |||
| if(!state.tensor_usage_counts.find(id)) | |||
| { | |||
| continue; | |||
| } | |||
| state.tensor_usage_counts[id]--; | |||
| if (state.tensor_usage_counts[id] > 0) | |||
| if(state.tensor_usage_counts[id] > 0) | |||
| { | |||
| continue; | |||
| if (!tensor_tape_.find(id, out var tape_it)) | |||
| } | |||
| if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||
| { | |||
| if (gradients.find(id, out var grad_it)) | |||
| if (gradients.find(id)) | |||
| { | |||
| // foreach (var g in grad_it) | |||
| // DeleteGradient(g); | |||
| gradients.erase(id); | |||
| } | |||
| continue; | |||
| } | |||
| var op_id = tape_it; | |||
| if (op_id == -1) | |||
| long op_id = tape_it; | |||
| if(op_id == -1) | |||
| { | |||
| continue; | |||
| if (state.op_missing_tensor.find(op_id, out var missing_it)) | |||
| } | |||
| if(state.op_missing_tensor.find(op_id)) | |||
| { | |||
| state.op_missing_tensor[op_id]--; | |||
| if (state.op_missing_tensor[op_id] == 0) | |||
| if(state.op_missing_tensor[op_id] == 0) | |||
| { | |||
| op_stack.Enqueue(op_id); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (state.op_tape.Count > 0) | |||
| if(state.op_tape.Count > 0) | |||
| { | |||
| throw new RuntimeError("Invalid tape state."); | |||
| var result = new Tensor[source_tensor_ids.Length]; | |||
| var j = 0; | |||
| foreach (var id in source_tensor_ids) | |||
| } | |||
| Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||
| for(int i = 0; i < source_tensor_ids.Length; i++) | |||
| { | |||
| if (gradients.find(id, out var grad_it)) | |||
| long tensor_id = source_tensor_ids[i]; | |||
| if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||
| { | |||
| if (grad_it.Count > 1) | |||
| result[j] = gen_math_ops.add_n(grad_it.ToArray()); | |||
| else | |||
| result[j] = grad_it[0]; | |||
| result[i] = null; | |||
| } | |||
| else | |||
| { | |||
| if(grad_it.Count > 1) | |||
| { | |||
| Tensor grad = AggregateGradients(grad_it); | |||
| grad_it.Clear(); | |||
| grad_it.Add(grad); | |||
| } | |||
| result[i] = grad_it[0]; | |||
| } | |||
| j++; | |||
| } | |||
| return result; | |||
| } | |||
| UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | |||
| { | |||
| var m = new UnorderedMap<string, UnorderedSet<int>>(); | |||
| m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
| m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
| return m; | |||
| return _functionsAcceptingNoneForIndicesMap; | |||
| } | |||
| UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids, | |||
| UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||
| Tensor[] output_gradients, | |||
| UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
| List<Tensor> output_gradients, | |||
| TensorTape tensor_tape, | |||
| OpTape op_tape) | |||
| { | |||
| var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>(); | |||
| for (int i = 0; i < target_tensor_ids.Length; ++i) | |||
| var result = new UnorderedMap<long, List<Tensor>>(); | |||
| for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||
| { | |||
| var id = target_tensor_ids[i]; | |||
| if (output_gradients.Length == 0 || output_gradients[i] == null) | |||
| long id = target_tensor_ids[i]; | |||
| if(output_gradients is null || output_gradients.Count == 0 || output_gradients[i] is null) | |||
| { | |||
| if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) | |||
| if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||
| { | |||
| if (!op_tape.find(tensor_tape[id], out var op_it)) | |||
| if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||
| { | |||
| throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
| "failed to find operation producing a tensor"); | |||
| "failed to find operation producing a tensor."); | |||
| } | |||
| bool found = false; | |||
| for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||
| for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||
| { | |||
| if (op_it.output_tensor_info[j].GetTensor() == id) | |||
| if (op_it.output_tensor_info[j].GetID() == id) | |||
| { | |||
| found = true; | |||
| var ones = op_it.output_tensor_info[j].OnesLike(); | |||
| result[id].Add(ones); | |||
| Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||
| result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
| break; | |||
| } | |||
| } | |||
| if (!found) | |||
| { | |||
| throw new ValueError("Internal state of the gradient tape is invalid: " + | |||
| "none of operations outputs match expected tensor"); | |||
| throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
| "none of operations outputs match expected tensor."); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (sources_that_are_targets.find(id, out var source_tensor)) | |||
| result[id].Add(source_tensor.OnesLike()); | |||
| if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||
| { | |||
| Tensor ones_like = BuildOnesLike(source_tensor); | |||
| result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| result[id].Add(output_gradients[i]); | |||
| result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||
| } | |||
| } | |||
| @@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||
| } | |||
| return result; | |||
| } | |||
| Tensor BuildOnesLike(TapeTensor t) | |||
| { | |||
| return t.OnesLike(); | |||
| } | |||
| Tensor AggregateGradients(List<Tensor> gradient_tensors) | |||
| { | |||
| if(gradient_tensors.Count == 1) | |||
| { | |||
| return gradient_tensors[0]; | |||
| } | |||
| return tf.add_n(gradient_tensors.ToArray()); | |||
| } | |||
| void DeleteGradient(Tensor gradient) | |||
| { | |||
| // Do not do anything here. Because GC will collect it when it has no reference. | |||
| } | |||
| long NumElements(Tensor tensor) => 1; // placeholder: effectively disables the byte-size based eager aggregation above. | |||
| } | |||
| } | |||
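| // Editorial end-to-end sketch of the id-based tape path above, via the public API: | |||
| // Watch seeds tensor_tape_, recorded ops fill op_tape_, and gradient() drives | |||
| // ComputeGradient. Expected result: dy/dx = 2x = 6. | |||
| using var g = tf.GradientTape(); | |||
| var x = tf.constant(3.0f); | |||
| g.watch(x); | |||
| var y = x * x; | |||
| var dy_dx = g.gradient(y, x); | |||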
| @@ -5,63 +5,62 @@ namespace Tensorflow.Gradients | |||
| { | |||
| public partial class Tape | |||
| { | |||
| public BackpropInitialState PrepareBackprop(Tensor[] target, | |||
| public BackpropInitialState PrepareBackprop(long[] target, | |||
| TensorTape tensor_tape, | |||
| OpTape op_tape, | |||
| UnorderedSet<Tensor> sources_set, | |||
| UnorderedSet<long> sources_set, | |||
| bool persistent_tape) | |||
| { | |||
| Stack<long> tensor_stack = new Stack<long>(); | |||
| foreach(var t in target) | |||
| { | |||
| tensor_stack.Push(t); | |||
| } | |||
| BackpropInitialState result = new BackpropInitialState(); | |||
| var tensor_stack = new Queue<Tensor>(target); | |||
| while (tensor_stack.Count > 0) | |||
| while(tensor_stack.Count > 0) | |||
| { | |||
| var tensor_id = tensor_stack.Dequeue(); | |||
| if (!tensor_tape.find(tensor_id, out var op_id)) | |||
| long tensor_id = tensor_stack.Pop(); | |||
| if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) | |||
| { | |||
| continue; | |||
| if (op_id == -1 || | |||
| !op_tape.find(op_id, out var op_it) || | |||
| result.op_tape.find(op_id, out var result_op_it)) | |||
| } | |||
| if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) | |||
| || result.op_tape.find(op_id)) | |||
| { | |||
| continue; | |||
| } | |||
| result.op_tape.emplace(op_id, op_it); | |||
| foreach (var it in op_it.input_tensor_id) | |||
| foreach(var it in op_it.input_tensor_id) | |||
| { | |||
| if (result.tensor_usage_counts.find(it)) | |||
| if(result.tensor_usage_counts.find(it)) | |||
| { | |||
| result.tensor_usage_counts[it]++; | |||
| } | |||
| else | |||
| { | |||
| result.tensor_usage_counts[it] = 1; | |||
| if (tensor_tape.find(it)) | |||
| tensor_stack.Enqueue(it); | |||
| { | |||
| tensor_stack.Push(it); | |||
| } | |||
| } | |||
| } | |||
| if (!persistent_tape) | |||
| op_tape.Remove(op_id); | |||
| { | |||
| op_tape.erase(op_id); | |||
| } | |||
| } | |||
| foreach (var pair in result.tensor_usage_counts) | |||
| foreach(var pair in result.tensor_usage_counts) | |||
| { | |||
| if (tensor_tape.find(pair.Key, out var it) && it != -1) | |||
| result.op_missing_tensor[it] += 1; | |||
| if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) | |||
| { | |||
| result.op_missing_tensor[it]++; | |||
| } | |||
| } | |||
| if (!persistent_tape) | |||
| { | |||
| // Call destructors for all unneeded gradient functions and | |||
| // clear the op_tape. We can clear the tape because ownership of | |||
| // backward functions that will be used for gradient computation | |||
| // has been transferred to `result`. | |||
| /*for (const auto&op_pair : *op_tape) { | |||
| op_pair.second.backward_function_deleter( | |||
| op_pair.second.backward_function); | |||
| }*/ | |||
| op_tape.Clear(); | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
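| // Editorial worked example: for y = x * x recorded as a single Mul op, PrepareBackprop | |||
| // pushes y's id, moves the Mul entry into result.op_tape, and counts x twice in | |||
| // tensor_usage_counts because x feeds both Mul inputs. | |||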
| @@ -8,34 +8,45 @@ namespace Tensorflow.Gradients | |||
| public partial class Tape | |||
| { | |||
| long next_op_id_ = 0; | |||
| UnorderedMap<Tensor, long> tensor_usage_; | |||
| UnorderedMap<long, long> tensor_usage_; | |||
| public void RecordOperation(string op_type, | |||
| Tensor[] input_tensors, | |||
| TapeTensor[] output_tensors, | |||
| long[] input_tensor_id, | |||
| TF_DataType[] input_dtypes, | |||
| BackwardFunction backward_function) | |||
| { | |||
| if (!ShouldRecord(input_tensors)) | |||
| if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
| return; | |||
| var op_id = next_op_id_++; | |||
| foreach (var i in input_tensors) | |||
| foreach (var i in input_tensor_id) | |||
| { | |||
| tensor_usage_[i]++; | |||
| } | |||
| long op_id = next_op_id_++; | |||
| foreach (var o in output_tensors) | |||
| { | |||
| tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | |||
| tensor_tape_[o.GetTensor()] = op_id; | |||
| tensor_usage_[o.GetTensor()] = 1; | |||
| tensor_tape_[o.GetID()] = op_id; | |||
| tensor_usage_[o.GetID()] = 1; | |||
| } | |||
| op_tape_[op_id] = new OpTapeEntry | |||
| { | |||
| op_type = op_type, | |||
| output_tensor_info = output_tensors, | |||
| input_tensor_id = input_tensors, | |||
| output_tensor_info = output_tensors.ToArray(), | |||
| input_tensor_id = input_tensor_id.ToArray(), | |||
| backward_function = backward_function | |||
| }; | |||
| } | |||
| public void RecordOperation(string op_type, | |||
| Tensor[] outputs, | |||
| Tensor[] inputs, | |||
| BackwardFunction backward_function) | |||
| { | |||
| tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -29,7 +30,7 @@ namespace Tensorflow.Gradients | |||
| _created_eagerly = tf.Context.executing_eagerly(); | |||
| tensor_tape_ = new TensorTape(); | |||
| op_tape_ = new OpTape(); | |||
| tensor_usage_ = new UnorderedMap<Tensor, long>(); | |||
| tensor_usage_ = new UnorderedMap<long, long>(); | |||
| if(_created_eagerly) | |||
| tf.Context.start_step(); | |||
| // nesting_id = ++tape_nesting_id_counter; | |||
| @@ -42,29 +43,28 @@ namespace Tensorflow.Gradients | |||
| public void Watch(Tensor x) | |||
| { | |||
| tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | |||
| tensor_tape_.emplace(x, -1); | |||
| tensor_tape_.emplace(x.Id, -1); | |||
| } | |||
| public bool ShouldRecord(Tensor[] tensors) | |||
| public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) | |||
| { | |||
| var dtypes = tensors.Select(x => x.dtype).ToArray(); | |||
| for (int i = 0; i < tensors.Length; ++i) | |||
| Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); | |||
| for (int i = 0; i < tensor_ids.Length; ++i) | |||
| { | |||
| if (tensor_tape_.find(tensors[i])) | |||
| if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) | |||
| { | |||
| if (IsDtypeTrainable(dtypes[i])) | |||
| return true; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
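| // Editorial example: after Watch(x), ShouldRecord(new[] { x.Id }, | |||
| // new[] { TF_DataType.TF_FLOAT }) returns true, while an int32 tensor id would not, | |||
| // since IsDtypeTrainable filters out non-floating dtypes. | |||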
| public void VariableAccessed(ResourceVariable variable) | |||
| public void VariableAccessed(IVariableV1 variable) | |||
| { | |||
| Watch(variable.Handle); | |||
| } | |||
| public ResourceVariable[] WatchedVariables() | |||
| public IVariableV1[] WatchedVariables() | |||
| { | |||
| return null; | |||
| } | |||
| @@ -1,27 +1,63 @@ | |||
| using static Tensorflow.Binding; | |||
| using OneOf; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Gradients | |||
| { | |||
| public class TapeTensor | |||
| { | |||
| Tensor tensor; | |||
| long id => tensor.Id; | |||
| TF_DataType dtype => tensor.dtype; | |||
| Shape shape => tensor.shape; | |||
| internal Tensor tensor; | |||
| internal long id; | |||
| internal TF_DataType dtype; | |||
| internal OneOf<Shape, Tensor> shape; | |||
| public TapeTensor(long id, TF_DataType dtype, Shape shape) | |||
| { | |||
| this.id = id; | |||
| this.dtype = dtype; | |||
| this.shape = shape; | |||
| } | |||
| public TapeTensor(long id, TF_DataType dtype, Tensor shape) | |||
| { | |||
| this.id = id; | |||
| this.dtype = dtype; | |||
| this.shape = shape; | |||
| } | |||
| public TapeTensor(Tensor tensor) | |||
| { | |||
| this.id = tensor.Id; | |||
| this.dtype = tensor.dtype; | |||
| this.shape = tensor.shape; | |||
| this.tensor = tensor; | |||
| } | |||
| public long GetID() => tensor.Id; | |||
| public Tensor GetTensor() => tensor; | |||
| public long GetID() => id; | |||
| public Tensor ZerosLike() | |||
| => tf.zeros(shape: shape, dtype: dtype); | |||
| { | |||
| if(dtype == dtypes.resource) | |||
| { | |||
| return null; | |||
| } | |||
| if(shape.Index == 1) | |||
| { | |||
| return tf.zeros_like(shape.AsT1); | |||
| } | |||
| return tf.zeros(shape.AsT0, dtype); | |||
| } | |||
| public Tensor OnesLike() | |||
| => tf.ones(shape: shape, dtype: dtype); | |||
| { | |||
| if (shape.Index == 1) | |||
| { | |||
| return tf.ones_like(shape.AsT1); | |||
| } | |||
| return tf.ones(shape.AsT0, dtype); | |||
| } | |||
| public override string ToString() | |||
| => $"{id}, {shape}, {dtype.as_numpy_name()}"; | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.Gradients | |||
| /// produced this tensor. A value of -1 means that the tensor was directly | |||
| /// watched and not the result of any operation in the tape. | |||
| /// </summary> | |||
| public class TensorTape : UnorderedMap<Tensor, long> | |||
| public class TensorTape : UnorderedMap<long, long> | |||
| { | |||
| } | |||
| @@ -704,32 +704,7 @@ namespace Tensorflow | |||
| public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||
| { | |||
| var tape_set = tf.GetTapeSet(); | |||
| bool some_tape_watching = false; | |||
| if(tape_set is not null && tape_set.Count > 0) | |||
| { | |||
| foreach(var tape in tape_set) | |||
| { | |||
| if (tape.ShouldRecord(tensors)) | |||
| { | |||
| if(tape.Persistent || some_tape_watching) | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
| } | |||
| some_tape_watching = true; | |||
| } | |||
| } | |||
| } | |||
| // skip the forward_accumulators. | |||
| if (some_tape_watching) | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
| } | |||
| else | |||
| { | |||
| return POSSIBLE_GRADIENT_TYPES_NONE; | |||
| } | |||
| return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); | |||
| } | |||
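| // Illustration (editor sketch, assuming tf.GradientTape exposes `persistent`): | |||
| // no tape watching → POSSIBLE_GRADIENT_TYPES_NONE; exactly one non-persistent | |||
| // tape → POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; a persistent tape (or two tapes) | |||
| // watching the same tensor → POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER: | |||
| // | |||
| //     var x = tf.constant(1.0f); | |||
| //     using var tape = tf.GradientTape(persistent: true); | |||
| //     tape.watch(x); | |||
| //     var t = PossibleTapeGradientTypes(new[] { x });  // HIGHER_ORDER | |||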
| /// <summary> | |||
| @@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable | |||
| return tensor; | |||
| } | |||
| public void watch_variable(IVariableV1 v) | |||
| { | |||
| if (_resource_tensor_inputs.Contains(v.Handle)) | |||
| { | |||
| return; | |||
| } | |||
| _watched_variables.Add(new WeakReference<IVariableV1>(v)); | |||
| // TODO(Rinne): also walk outer graphs, as Python's FuncGraph.watch_variable does. | |||
| } | |||
| Tensor capture_eager_tensor(Tensor tensor, string name) | |||
| { | |||
| Tensor graph_const = null; | |||
| @@ -4,10 +4,10 @@ public interface IOptimizer | |||
| { | |||
| Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | |||
| Tensor[] clip_gradients(Tensor[] grads); | |||
| void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
| void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true); | |||
| void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
| void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true); | |||
| } | |||
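| // Usage sketch (editor note): with the widened signatures any IVariableV1 | |||
| // satisfies the interface, not only ResourceVariable. Assuming Adam implements | |||
| // IOptimizer, as Model.train_step relies on: | |||
| // | |||
| //     IOptimizer optimizer = tf.keras.optimizers.Adam(); | |||
| //     IVariableV1 v = tf.Variable(1.0f); | |||
| //     Tensor grad = tf.constant(0.1f); | |||
| //     optimizer.apply_gradients((grad, v));           // single-pair overload | |||
| //     optimizer.apply_gradients(new[] { (grad, v) },  // enumerable overload | |||
| //         experimental_aggregate_gradients: false); | |||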
| @@ -208,9 +208,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | |||
| //[DllImport(TensorFlowLibName)] | |||
| //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| //[DllImport(TensorFlowLibName)] | |||
| //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
| } | |||
| } | |||
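| // Round-trip sketch (editor note), mirroring the call sites changed below in | |||
| // handle_data_util and ops (`tensor` and its graph assumed valid): | |||
| // | |||
| //     var ptr = c_api.TFC_GetHandleShapeAndType(tensor.graph.c_graph, tensor._as_tf_output()); | |||
| //     var handle_data = HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(ptr))); | |||
| //     Status status = new(); | |||
| //     var proto = handle_data.ToByteArray(); | |||
| //     c_api.TFC_SetHandleShapeAndType(tensor.graph.c_graph, tensor._as_tf_output(), proto, proto.Length, status); | |||
| //     status.Check(true); | |||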
| @@ -39,7 +39,7 @@ namespace Tensorflow | |||
| if (config is null) | |||
| { | |||
| config = function_utils.get_disabled_rewriter_config().ToString(); | |||
| config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
| } | |||
| if (executor_type is null) | |||
| @@ -49,6 +49,8 @@ namespace Tensorflow | |||
| if (executing_eagerly) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -210,7 +211,51 @@ namespace Tensorflow | |||
| /// <param name="name">A name for the operation (optional).</param> | |||
| /// <returns>A `Tensor`. Has the same type as `value`.</returns> | |||
| public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||
| => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value)); | |||
| return _result[0]; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| // The fast path can fail (e.g. on symbolic inputs); fall through to the eager fallback. | |||
| } | |||
| try | |||
| { | |||
| return fill_eager_fallback(dims, value as Tensor, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| // The eager fallback failed too; fall through to graph-mode op construction. | |||
| } | |||
| } | |||
| Dictionary<string, object> attrs = new Dictionary<string, object>(); | |||
| attrs["dims"] = dims; | |||
| attrs["value"] = value; | |||
| var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return result.output; | |||
| } | |||
| public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx) | |||
| { | |||
| // "T" is the value/output dtype; "index_type" is the dtype of dims (int32/int64). | |||
| object[] attrs = new object[] { "T", value.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() }; | |||
| var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return _result[0]; | |||
| } | |||
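| // Usage sketch (editor note; enclosing ops class assumed): eager mode takes | |||
| // the fast path, graph mode falls through to _apply_op_helper. | |||
| // | |||
| //     var dims = tf.constant(new[] { 2, 3 }); | |||
| //     var filled = fill(dims, tf.constant(9f));  // 2x3 tensor filled with 9 | |||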
| /// <summary> | |||
| /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. | |||
| @@ -49,8 +49,10 @@ namespace Tensorflow.Operations | |||
| target_t.HandleData = handle_data; | |||
| return; | |||
| } | |||
| // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) | |||
| //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); | |||
| Status status = new(); | |||
| var proto = handle_data.ToByteArray(); | |||
| c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status); | |||
| status.Check(true); | |||
| } | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); | |||
| @@ -25,6 +25,7 @@ using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using System.Buffers; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -302,5 +303,18 @@ namespace Tensorflow | |||
| // return handle_data_util.get_resource_handle_data(handle); | |||
| //} | |||
| } | |||
| public static void variable_accessed(IVariableV1 variable) | |||
| { | |||
| if (ops.get_default_graph() is FuncGraph func_graph) | |||
| { | |||
| func_graph.watch_variable(variable); | |||
| } | |||
| if (variable.Trainable) | |||
| { | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| tape.VariableAccessed(variable); | |||
| } | |||
| } | |||
| } | |||
| } | |||
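| // Net effect, sketched (editor note): reading a trainable variable under an | |||
| // active tape watches its handle on every tape and registers it with an | |||
| // enclosing FuncGraph, so gradients flow without an explicit tape.watch: | |||
| // | |||
| //     using var tape = tf.GradientTape(); | |||
| //     var v = tf.Variable(2.0f); | |||
| //     var y = v * 3.0f;             // the read triggers variable_accessed(v) | |||
| //     var g = tape.gradient(y, v);  // ≈ 3.0 | |||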
| @@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.6.2" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | |||
| </ItemGroup> | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| { | |||
| public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); | |||
| public Tensor() | |||
| protected Tensor() | |||
| { | |||
| } | |||
| @@ -108,6 +108,7 @@ namespace Tensorflow | |||
| protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||
| { | |||
| _handle = TF_NewTensor(shape, dtype, null); | |||
| _id = ops.uid(); | |||
| } | |||
| protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) | |||
| @@ -116,6 +117,7 @@ namespace Tensorflow | |||
| _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); | |||
| else | |||
| _handle = TF_NewTensor(bytes, shape, dtype); | |||
| _id = ops.uid(); | |||
| } | |||
| protected unsafe void InitTensor(Array array, Shape? shape = null) | |||
| @@ -166,6 +168,8 @@ namespace Tensorflow | |||
| string[] val => StringTensor(val, shape), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| _id = ops.uid(); | |||
| } | |||
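| // Editor note: the tape structures are now keyed by the long id alone, so | |||
| // every construction path must assign _id. Hypothetical sanity check: | |||
| // | |||
| //     var a = tf.constant(1.0f); | |||
| //     var b = tf.constant(1.0f); | |||
| //     Debug.Assert(a.Id != b.Id);  // distinct tensors must never share an id | |||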
| unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged | |||
| @@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| IEnumerable<ConcreteFunction> _concrete_functions; | |||
| FunctionSpec _function_spec; | |||
| public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions; | |||
| public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec, | |||
| IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false) | |||
| { | |||
| @@ -25,6 +25,19 @@ namespace Tensorflow.Util | |||
| } | |||
| } | |||
| public Tv SetDefault(Tk key, Tv default_value) | |||
| { | |||
| if (TryGetValue(key, out var res)) | |||
| { | |||
| return res; | |||
| } | |||
| base[key] = default_value; | |||
| return default_value; | |||
| } | |||
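| // Usage sketch (editor note; semantics follow Python's dict.setdefault): | |||
| // | |||
| //     var usage = new UnorderedMap<long, long>(); | |||
| //     var first = usage.SetDefault(42, 0);   // absent → stores 0, returns 0 | |||
| //     usage[42] = 7; | |||
| //     var second = usage.SetDefault(42, 0);  // present → returns 7, default ignored | |||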
| public void push_back(Tk key, Tv value) | |||
| => this[key] = value; | |||
| @@ -9,6 +9,7 @@ using System.Diagnostics; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using OneOf; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -193,6 +194,10 @@ namespace Tensorflow | |||
| /// </summary> | |||
| void variable_accessed(BaseResourceVariable variable) | |||
| { | |||
| if(ops.get_default_graph() is FuncGraph func_graph) | |||
| { | |||
| func_graph.watch_variable(variable as IVariableV1); | |||
| } | |||
| if (variable.Trainable) | |||
| { | |||
| foreach (var tape in tf.GetTapeSet()) | |||
| @@ -575,12 +575,8 @@ namespace Tensorflow | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // This implementation hasn't been checked for some reasons. | |||
| // If it throws an exception in the future, please check it. | |||
| //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| @@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine | |||
| /// the layer's weights. | |||
| /// </summary> | |||
| protected bool built; | |||
| public bool Built => built; | |||
| public bool Built | |||
| { | |||
| get => built; | |||
| internal set => built = value; | |||
| } | |||
| public bool Trainable => args.Trainable; | |||
| public TF_DataType DType => args.DType; | |||
| public bool AutoCast => args.Autocast; | |||
| @@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| protected List<ILayer> _self_tracked_trackables; | |||
| /// <summary> | |||
| /// If set, calling the layer invokes this function directly instead of the layer's own Call. | |||
| /// </summary> | |||
| public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null; | |||
| public Layer(LayerArgs args) | |||
| { | |||
| Initialize(args); | |||
| @@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine | |||
| /// <returns></returns> | |||
| protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| if(ReplacedCall is not null) | |||
| { | |||
| return ReplacedCall(inputs); | |||
| } | |||
| return inputs; | |||
| } | |||
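| // Hypothetical wiring (editor note), as the SavedModel loader does with a | |||
| // restored function: | |||
| // | |||
| //     Func<Tensors, Tensors> restored = ...;  // e.g. a RestoredFunction's Apply | |||
| //     layer.ReplacedCall = restored;          // Call(inputs) now returns restored(inputs) | |||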
| @@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| (x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||
| using var tape = tf.GradientTape(); | |||
| //foreach (var variable in TrainableVariables) | |||
| //{ | |||
| // tape.watch(variable.Handle); | |||
| //} | |||
| var y_pred = Apply(x, training: true); | |||
| var loss = compiled_loss.Call(y, y_pred); | |||
| @@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine | |||
| gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables)); | |||
| gradients = optimizer.clip_gradients(gradients); | |||
| optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)), | |||
| optimizer.apply_gradients(zip(gradients, trainable_variables), | |||
| experimental_aggregate_gradients: false); | |||
| } | |||
| } | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| _set_hyper("decay", args.InitialDecay); | |||
| } | |||
| public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
| public void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true) | |||
| => apply_gradients(new[] { grads_and_vars }, | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| /// <param name="grads_and_vars"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="experimental_aggregate_gradients"></param> | |||
| public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
| public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true) | |||
| { | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| }); | |||
| } | |||
| void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | |||
| void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state) | |||
| { | |||
| _resource_apply_dense(var, grad, apply_state); | |||
| // if var.constraint is not None: | |||
| @@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| throw new NotImplementedException("_resource_apply_dense"); | |||
| } | |||
| void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
| void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
| string name, | |||
| Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) | |||
| { | |||
| @@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="layers"></param> | |||
| private void _finalize_saved_model_layers(List<Layer> layers) | |||
| { | |||
| foreach(var layer in layers) | |||
| { | |||
| layer.Built = true; | |||
| var keras_attr = _get_keras_attr(layer); | |||
| if(keras_attr is not Trackable trackable) | |||
| { | |||
| continue; | |||
| } | |||
| if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call)) | |||
| { | |||
| Debug.Assert(layer_call is RestoredFunction); | |||
| var restored = (RestoredFunction)layer_call; | |||
| var concrete_functions = restored.ConcreteFunctions; | |||
| if (concrete_functions is not null && concrete_functions.Any()) | |||
| { | |||
| layer.ReplacedCall = use_wrapped_call(layer, restored.Apply); | |||
| } | |||
| } | |||
| } | |||
| foreach(var layer in layers) | |||
| { | |||
| // TODO(Rinne): deal with `RevivedNetwork`. | |||
| @@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving | |||
| } | |||
| } | |||
| private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call) | |||
| { | |||
| // TODO(Rinne): revise it. | |||
| return call; | |||
| } | |||
| private void _restore_layer_unconditional_losses(Layer layer) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| @@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| return _config; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||
| { | |||
| return base.Call(inputs, state, training); | |||
| } | |||
| else | |||
| { | |||
| return (func as Function).Apply(inputs); | |||
| } | |||
| } | |||
| //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| //{ | |||
| // if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) | |||
| // { | |||
| // return base.Call(inputs, state, training); | |||
| // } | |||
| // else | |||
| // { | |||
| // return (func as Function).Apply(inputs); | |||
| // } | |||
| //} | |||
| } | |||
| } | |||
| @@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), | |||
| // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||
| base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), | |||
| functions.Concat(new string[] { })) | |||
| functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })) | |||
| { | |||
| } | |||
| @@ -64,23 +64,19 @@ public class SequentialModelLoad | |||
| var model = tf.keras.models.load_model(@"Assets/python_func_model"); | |||
| model.summary(); | |||
| var x = tf.random.uniform((8, 784), -1, 1); | |||
| var y = model.Apply(x); | |||
| Console.WriteLine(y); | |||
| model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| //var data_loader = new MnistModelLoader(); | |||
| //var num_epochs = 1; | |||
| //var batch_size = 8; | |||
| var data_loader = new MnistModelLoader(); | |||
| var num_epochs = 1; | |||
| var batch_size = 8; | |||
| //var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| //{ | |||
| // TrainDir = "mnist", | |||
| // OneHot = false, | |||
| // ValidationSize = 58000, | |||
| //}).Result; | |||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| { | |||
| TrainDir = "mnist", | |||
| OneHot = false, | |||
| ValidationSize = 55000, | |||
| }).Result; | |||
| //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| } | |||
| } | |||