| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public partial class Tape | |||||
| { | |||||
| public Tensor[] CallBackwardFunction(BackwardFunction backward_function, | |||||
| List<long> unneeded_gradients, | |||||
| List<Tensor> output_gradients) | |||||
| { | |||||
| var grads = new Tensor[output_gradients.Count]; | |||||
| var result = backward_function(output_gradients.ToArray(), | |||||
| unneeded_gradients.ToArray()); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,249 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public partial class Tape | |||||
| { | |||||
| int kMinAggregateCount = 4; | |||||
| int kMinAggregateBytes = 128 * 1024 * 1024; | |||||
| public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||||
| long[] source_tensor_ids, | |||||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
| Tensor[] output_gradients) | |||||
| { | |||||
| var result = new List<Tensor>(source_tensor_ids.Length); | |||||
| var sources_set = new UnorderedSet<long>(source_tensor_ids); | |||||
| var gradients_size = new UnorderedMap<long, long>(); | |||||
| var state = PrepareBackprop( | |||||
| target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_); | |||||
| var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||||
| var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||||
| output_gradients, | |||||
| tensor_tape_, | |||||
| state.op_tape); | |||||
| while (op_stack.Count > 0) | |||||
| { | |||||
| var op = op_stack.Dequeue(); | |||||
| if (!state.op_tape.find(op, out var trace)) | |||||
| continue; | |||||
| state.op_tape.erase(op); | |||||
| var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||||
| var unneeded_gradients = new List<long>(); | |||||
| for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||||
| { | |||||
| var in_tensor_id = trace.input_tensor_id[i]; | |||||
| if (!tensor_tape_.find(in_tensor_id) && | |||||
| !sources_set.find(in_tensor_id)) | |||||
| unneeded_gradients.Add(i); | |||||
| } | |||||
| bool any_gradient_nonzero = false; | |||||
| var zero_indices = new List<int>(); | |||||
| for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||||
| { | |||||
| var id = trace.output_tensor_info[i].GetID(); | |||||
| if (!gradients.find(id, out var grad_it)) | |||||
| { | |||||
| throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap"); | |||||
| } | |||||
| else | |||||
| { | |||||
| any_gradient_nonzero = true; | |||||
| var new_gradients = grad_it.Count == 1 ? | |||||
| grad_it[0] : | |||||
| gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||||
| if (!sources_set.find(id)) | |||||
| gradients.Remove(id); | |||||
| else | |||||
| { | |||||
| grad_it.Clear(); | |||||
| grad_it.Add(new_gradients); | |||||
| // vspace.MarkAsResult(new_gradients); | |||||
| } | |||||
| out_gradients.Add(new_gradients); | |||||
| } | |||||
| } | |||||
| Tensor[] in_gradients; | |||||
| if (any_gradient_nonzero) | |||||
| { | |||||
| foreach (var i in zero_indices) | |||||
| out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||||
| in_gradients = CallBackwardFunction(trace.backward_function, | |||||
| unneeded_gradients, | |||||
| out_gradients); | |||||
| if (in_gradients.Count() != trace.input_tensor_id.Count()) | |||||
| throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||||
| if (!persistent_) | |||||
| { | |||||
| // trace.backward_function_deleter(trace.backward_function); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| for (int i = 0; i < in_gradients.Length; ++i) | |||||
| { | |||||
| var id = trace.input_tensor_id[i]; | |||||
| if (in_gradients[i] != null) | |||||
| { | |||||
| var unaggregated_grads = gradients[id]; | |||||
| unaggregated_grads.Add(in_gradients[i]); | |||||
| if(unaggregated_grads.Count > kMinAggregateCount) | |||||
| { | |||||
| if(!gradients_size.ContainsKey(id)) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| if (!state.tensor_usage_counts.find(id)) | |||||
| continue; | |||||
| state.tensor_usage_counts[id]--; | |||||
| if (state.tensor_usage_counts[id] > 0) | |||||
| continue; | |||||
| if (!tensor_tape_.find(id, out var tape_it)) | |||||
| { | |||||
| if (gradients.find(id, out var grad_it)) | |||||
| { | |||||
| // foreach (var g in grad_it) | |||||
| // DeleteGradient(g); | |||||
| gradients.erase(id); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| var op_id = tape_it; | |||||
| if (op_id == -1) | |||||
| continue; | |||||
| if(state.op_missing_tensor.find(op_id, out var missing_it)) | |||||
| { | |||||
| state.op_missing_tensor[op_id]--; | |||||
| if (state.op_missing_tensor[op_id] == 0) | |||||
| op_stack.Enqueue(op_id); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (state.op_tape.Count > 0) | |||||
| throw new RuntimeError("Invalid tape state."); | |||||
| var used_gradient_ids = new List<long>(source_tensor_ids.Length); | |||||
| foreach (var id in source_tensor_ids) | |||||
| { | |||||
| if (!gradients.find(id, out var grad_it)) | |||||
| result.Add(null); | |||||
| else | |||||
| { | |||||
| if(grad_it.Count > 1) | |||||
| { | |||||
| var grad = gen_math_ops.add_n(grad_it.ToArray()); | |||||
| grad_it.Clear(); | |||||
| grad_it.Add(grad); | |||||
| } | |||||
| result.Add(grad_it[0]); | |||||
| used_gradient_ids.Add(id); | |||||
| } | |||||
| } | |||||
| /*foreach(var grad_pair in gradients) | |||||
| { | |||||
| if(!used_gradient_ids.Contains(grad_pair.Key)) | |||||
| { | |||||
| foreach(var g in grad_pair.Value) | |||||
| { | |||||
| vspace.DeleteGradient(g); | |||||
| } | |||||
| } | |||||
| }*/ | |||||
| return result.ToArray(); | |||||
| } | |||||
| UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||||
| UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
| Tensor[] output_gradients, | |||||
| TensorTape tensor_tape, | |||||
| OpTape<BackwardFunction, TapeTensor> op_tape) | |||||
| { | |||||
| var result = new UnorderedMapEnumerable<long, List<Tensor>>(); | |||||
| for (int i = 0; i < target_tensor_ids.Length; ++i) | |||||
| { | |||||
| var id = target_tensor_ids[i]; | |||||
| if (output_gradients.Length == 0 || output_gradients[i] == null) | |||||
| { | |||||
| if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1) | |||||
| { | |||||
| if (!op_tape.find(tensor_tape[id], out var op_it)) | |||||
| throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||||
| "failed to find operation producing a tensor"); | |||||
| bool found = false; | |||||
| for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||||
| { | |||||
| if (op_it.output_tensor_info[j].GetID() == id) | |||||
| { | |||||
| found = true; | |||||
| var ones = op_it.output_tensor_info[j].OnesLike(); | |||||
| result[id].Add(ones); | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!found) | |||||
| { | |||||
| throw new ValueError("Internal state of the gradient tape is invalid: " + | |||||
| "none of operations outputs match expected tensor"); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| if (sources_that_are_targets.find(id, out var source_tensor)) | |||||
| result[id].Add(source_tensor.OnesLike()); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| result[id].Add(output_gradients[i]); | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | |||||
| Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape, | |||||
| UnorderedMap<long, long> op_missing_tensor) | |||||
| { | |||||
| var result = new Queue<long>(); | |||||
| foreach(var op_entry in op_tape) | |||||
| { | |||||
| if (!op_missing_tensor.find(op_entry.Key)) | |||||
| result.Enqueue(op_entry.Key); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,72 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public partial class Tape | |||||
| { | |||||
| public BackpropInitialState PrepareBackprop(long[] target, | |||||
| TensorTape tensor_tape, | |||||
| OpTape<BackwardFunction, TapeTensor> op_tape, | |||||
| UnorderedSet<long> sources_set, | |||||
| bool persistent_tape) | |||||
| { | |||||
| BackpropInitialState result = new BackpropInitialState(); | |||||
| var tensor_stack = new Queue<long>(target); | |||||
| while (tensor_stack.Count > 0) | |||||
| { | |||||
| var tensor_id = tensor_stack.Dequeue(); | |||||
| if (!tensor_tape.find(tensor_id, out var op_id)) | |||||
| continue; | |||||
| if (op_id == -1 || | |||||
| !op_tape.find(op_id, out var op_it) || | |||||
| result.op_tape.find(op_id, out var result_op_it)) | |||||
| continue; | |||||
| result.op_tape.emplace(op_id, op_it); | |||||
| foreach (var it in op_it.input_tensor_id) | |||||
| { | |||||
| if(result.tensor_usage_counts.find(it)) | |||||
| result.tensor_usage_counts[it]++; | |||||
| else | |||||
| { | |||||
| result.tensor_usage_counts[it] = 1; | |||||
| if (tensor_tape.find(it)) | |||||
| tensor_stack.Enqueue(it); | |||||
| } | |||||
| } | |||||
| if (!persistent_tape) | |||||
| op_tape.Remove(op_id); | |||||
| } | |||||
| foreach (var pair in result.tensor_usage_counts) | |||||
| { | |||||
| if (tensor_tape.find(pair.Key, out var it) && it != -1) | |||||
| result.op_missing_tensor[it] += 1; | |||||
| } | |||||
| if (!persistent_tape) | |||||
| { | |||||
| // Call destructors for all unneeded gradient functions and | |||||
| // clear the op_tape. We can clear the tape because ownership of | |||||
| // backward functions that will be used for gradient computation | |||||
| // has been transferred to `result`. | |||||
| /*for (const auto&op_pair : *op_tape) { | |||||
| op_pair.second.backward_function_deleter( | |||||
| op_pair.second.backward_function); | |||||
| }*/ | |||||
| op_tape.Clear(); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,51 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Gradients | |||||
| { | |||||
| public partial class Tape | |||||
| { | |||||
| long next_op_id_ = 0; | |||||
| UnorderedMap<long, long> tensor_usage_; | |||||
| public void RecordOperation(string op_type, | |||||
| Tensor[] input_tensors, | |||||
| TapeTensor[] output_tensors, | |||||
| long[] input_tensor_id, | |||||
| TF_DataType[] input_dtypes, | |||||
| Func<BackwardFunction> backward_function_getter) | |||||
| { | |||||
| if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||||
| { | |||||
| return; | |||||
| } | |||||
| long op_id = next_op_id_++; | |||||
| var ids = new List<long>(input_tensor_id.Length); | |||||
| foreach (var i in input_tensor_id) | |||||
| { | |||||
| tensor_usage_[i]++; | |||||
| ids.Add(i); | |||||
| } | |||||
| var tensors = new List<TapeTensor>(output_tensors.Length); | |||||
| foreach (var o in output_tensors) | |||||
| { | |||||
| tensor_tape_[o.GetID()] = op_id; | |||||
| tensor_usage_[o.GetID()] = 1; | |||||
| tensors.Add(o); | |||||
| } | |||||
| op_tape_[op_id] = new OpTapeEntry<BackwardFunction, TapeTensor> | |||||
| { | |||||
| op_type = op_type, | |||||
| output_tensor_info = tensors.ToArray(), | |||||
| input_tensor_id = ids.ToArray(), | |||||
| backward_function = backward_function_getter() | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,50 +1,112 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.tensorflow; | |||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| public class Tape : ITape | |||||
| public partial class Tape : ITape | |||||
| { | { | ||||
| int nesting_id; | |||||
| static int tape_nesting_id_counter = 0; | |||||
| bool persistent_; | |||||
| bool watch_accessed_variables; | |||||
| TensorTape tensor_tape_; | |||||
| OpTape<BackwardFunction, TapeTensor> op_tape_; | |||||
| /// <summary> | |||||
| /// A deque-backed stack, whose element references are not invalidated by | |||||
| /// pushes and pops at the back. | |||||
| /// </summary> | |||||
| Stack<AccumulatorCallState> call_state_; | |||||
| public Tape(bool persistent, bool watch_accessed_variables) | public Tape(bool persistent, bool watch_accessed_variables) | ||||
| { | { | ||||
| this.persistent_ = persistent; | |||||
| this.watch_accessed_variables = watch_accessed_variables; | |||||
| } | |||||
| tensor_tape_ = new TensorTape(); | |||||
| op_tape_ = new OpTape<BackwardFunction, TapeTensor>(); | |||||
| tensor_usage_ = new UnorderedMap<long, long>(); | |||||
| public Tensor[] ComputeGradient(long[] target_tensor_ids, long[] source_tensor_ids, UnorderedMap<long, TapeTensor> sources_that_are_targets, Tensor[] output_gradients) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| nesting_id = ++tape_nesting_id_counter; | |||||
| tf.GetTapeSet().Add(this); | |||||
| } | } | ||||
| public void PopTape(ITape tape) | |||||
| /// <summary> | |||||
| /// Marks this tensor to be watched by the given tape. | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| public void Watch(long tensor_id) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| if (!CouldBackprop()) | |||||
| return; | |||||
| tensor_tape_.emplace(tensor_id, -1); | |||||
| } | } | ||||
| public void RecordOperation(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, long[] input_tensor_id, TF_DataType[] input_dtypes, Func<tensorflow.BackwardFunction> backward_function_getter) | |||||
| public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| for (int i = 0; i < tensor_ids.Length; ++i) | |||||
| { | |||||
| if (tensor_tape_.find(tensor_ids[i])) | |||||
| if (IsDtypeTrainable(dtypes[i])) | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes) | |||||
| /// <summary> | |||||
| /// Pops the given tape in the stack. | |||||
| /// </summary> | |||||
| /// <param name="tape"></param> | |||||
| public void PopTape(ITape tape) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| tf.GetTapeSet().Remove(tape); | |||||
| } | } | ||||
| public void VariableAccessed(ResourceVariable variable) | public void VariableAccessed(ResourceVariable variable) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| Watch(variable.Handle.Id); | |||||
| } | } | ||||
| public void Watch(long tensor_id) | |||||
| public ResourceVariable[] WatchedVariables() | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| return null; | |||||
| } | } | ||||
| public ResourceVariable[] WatchedVariables() | |||||
| public bool IsDtypeTrainable(TF_DataType dtype) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_HALF: | |||||
| case TF_DataType.TF_BFLOAT16: | |||||
| case TF_DataType.TF_FLOAT: | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| case TF_DataType.TF_COMPLEX64: | |||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_RESOURCE: | |||||
| case TF_DataType.TF_VARIANT: | |||||
| return true; | |||||
| default: | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| bool CouldForwardprop() | |||||
| => HasAccumulator(); | |||||
| bool CouldBackprop() | |||||
| => HasGradientTape(); | |||||
| bool HasAccumulator() | |||||
| //return !GetAccumulatorSet()->empty(); | |||||
| => false; | |||||
| bool HasGradientTape() | |||||
| => tf.GetTapeSet().Count > 0; | |||||
| } | } | ||||
| } | } | ||||