| @@ -35,6 +35,7 @@ namespace Tensorflow.Contexts | |||||
| public string ScopeName { get; set; } = ""; | public string ScopeName { get; set; } = ""; | ||||
| bool initialized = false; | bool initialized = false; | ||||
| ContextSwitchStack context_switches; | ContextSwitchStack context_switches; | ||||
| public FunctionCallOptions FunctionCallOptions { get; } | |||||
| public SafeContextHandle Handle { get; } | public SafeContextHandle Handle { get; } | ||||
| @@ -44,6 +45,7 @@ namespace Tensorflow.Contexts | |||||
| status.Check(true); | status.Check(true); | ||||
| context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | ||||
| initialized = true; | initialized = true; | ||||
| FunctionCallOptions = new FunctionCallOptions(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Google.Protobuf; | |||||
| using Google.Protobuf.Collections; | |||||
| namespace Tensorflow.Contexts | |||||
| { | |||||
| public class FunctionCallOptions | |||||
| { | |||||
| public string config_proto_serialized() | |||||
| { | |||||
| var config = new ConfigProto | |||||
| { | |||||
| AllowSoftPlacement = true, | |||||
| }; | |||||
| return config.ToByteString().ToStringUtf8(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -371,7 +371,7 @@ namespace Tensorflow.Eager | |||||
| switch (type) | switch (type) | ||||
| { | { | ||||
| case TF_AttrType.TF_ATTR_STRING: | case TF_AttrType.TF_ATTR_STRING: | ||||
| c_api.TFE_OpSetAttrString(op, key, value.ToString(), (uint)value.ToString().Length); | |||||
| c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length); | |||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_TYPE: | case TF_AttrType.TF_ATTR_TYPE: | ||||
| c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); | c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); | ||||
| @@ -241,7 +241,7 @@ namespace Tensorflow | |||||
| /// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
| /// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | ||||
| @@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| IntPtr _handle; | IntPtr _handle; | ||||
| FuncGraph func_graph; | FuncGraph func_graph; | ||||
| public Tensor[] CapturedInputs => func_graph.external_captures(); | |||||
| public string Name | public string Name | ||||
| { | { | ||||
| @@ -38,6 +39,8 @@ namespace Tensorflow.Functions | |||||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | ||||
| { | { | ||||
| func_graph = graph; | func_graph = graph; | ||||
| ToGraph(graph.Inputs, graph.Outputs); | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
| @@ -124,6 +127,21 @@ namespace Tensorflow.Functions | |||||
| return flat_outputs; | return flat_outputs; | ||||
| } | } | ||||
| public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | |||||
| { | |||||
| var new_args = new List<Tensor>(); | |||||
| new_args.AddRange(args); | |||||
| new_args.AddRange(captured_inputs); | |||||
| args = new_args.ToArray(); | |||||
| var attrs = new object[] | |||||
| { | |||||
| "executor_type", "", | |||||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||||
| }; | |||||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); | |||||
| } | |||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
| { | { | ||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | ||||
| @@ -48,7 +48,7 @@ namespace Tensorflow.Functions | |||||
| getBackwardFunction: () => backward_function); | getBackwardFunction: () => backward_function); | ||||
| } | } | ||||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) | |||||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||||
| { | { | ||||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | ||||
| { | { | ||||
| @@ -61,10 +61,11 @@ namespace Tensorflow.Functions | |||||
| processed_args.add(arg); | processed_args.add(arg); | ||||
| input_index += 1; | input_index += 1; | ||||
| } | } | ||||
| return output_grads;// backward.Invoke(processed_args.ToArray()); | |||||
| return backward.CallFlat(processed_args.ToArray(), outputs); | |||||
| }; | }; | ||||
| return (_backward_function_wrapper, flat_outputs); | |||||
| return (_backward_function_wrapper, outputs); | |||||
| } | } | ||||
| protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | ||||
| @@ -82,7 +83,7 @@ namespace Tensorflow.Functions | |||||
| } | } | ||||
| var gradients_wrt_outputs = new List<Tensor>(); | var gradients_wrt_outputs = new List<Tensor>(); | ||||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); | |||||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||||
| foreach (var output in trainable_outputs) | foreach (var output in trainable_outputs) | ||||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | ||||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | ||||
| @@ -90,16 +91,15 @@ namespace Tensorflow.Functions | |||||
| grad_ys: gradients_wrt_outputs.ToArray(), | grad_ys: gradients_wrt_outputs.ToArray(), | ||||
| src_graph: _func_graph); | src_graph: _func_graph); | ||||
| tf.Context.restore_mode(); | |||||
| var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; | |||||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||||
| var backward_function_attr = new Dictionary<string, string>(); | var backward_function_attr = new Dictionary<string, string>(); | ||||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | ||||
| gradients_wrt_outputs.append(backwards_graph.internal_captures()); | |||||
| backwards_graph.Inputs = gradients_wrt_outputs; | backwards_graph.Inputs = gradients_wrt_outputs; | ||||
| backwards_graph.Outputs = gradients_wrt_inputs; | backwards_graph.Outputs = gradients_wrt_inputs; | ||||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | ||||
| var forward_function_attr = new Dictionary<string, string>(); | var forward_function_attr = new Dictionary<string, string>(); | ||||
| forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | ||||
| var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | ||||
| @@ -49,14 +49,14 @@ namespace Tensorflow | |||||
| RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | ||||
| (oper, out_grads) => | (oper, out_grads) => | ||||
| { | { | ||||
| tf.Logger.Debug($"Caculate Gradient: {m.Name}"); | |||||
| tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | |||||
| var results = g.InvokeMember(m.Name, | var results = g.InvokeMember(m.Name, | ||||
| BindingFlags.InvokeMethod, | BindingFlags.InvokeMethod, | ||||
| null, | null, | ||||
| null, | null, | ||||
| args: new object[] { oper, out_grads }) as Tensor[]; | args: new object[] { oper, out_grads }) as Tensor[]; | ||||
| foreach (var result in results.Where(x => x != null)) | foreach (var result in results.Where(x => x != null)) | ||||
| tf.Logger.Debug($"{result.TensorShape}"); | |||||
| tf.Logger.Debug($"Gradient: {result.name} {result.TensorShape}"); | |||||
| return results; | return results; | ||||
| } | } | ||||
| ); | ); | ||||
| @@ -26,7 +26,9 @@ namespace Tensorflow.Graphs | |||||
| public Tensors Outputs { get; set; } | public Tensors Outputs { get; set; } | ||||
| public Dictionary<string, string> Attrs { get; set; } | public Dictionary<string, string> Attrs { get; set; } | ||||
| Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||||
| // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||||
| // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Construct a new FuncGraph. | /// Construct a new FuncGraph. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -129,7 +131,7 @@ namespace Tensorflow.Graphs | |||||
| Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | ||||
| { | { | ||||
| Tensor placeholder = null; | Tensor placeholder = null; | ||||
| if (!_captures.ContainsKey(tensor.Id)) | |||||
| if (!_captures.Contains(tensor.Id)) | |||||
| { | { | ||||
| placeholder = _create_substitute_placeholder(tensor, | placeholder = _create_substitute_placeholder(tensor, | ||||
| name: name, | name: name, | ||||
| @@ -139,7 +141,7 @@ namespace Tensorflow.Graphs | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| placeholder = _captures[tensor.Id].Item1; | |||||
| placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; | |||||
| } | } | ||||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | ||||
| @@ -557,16 +557,16 @@ namespace Tensorflow | |||||
| public Tensor[] external_captures() | public Tensor[] external_captures() | ||||
| { | { | ||||
| Tensor[] captures = new Tensor[this._captures.Count]; | |||||
| ICollection inner = this._captures.Keys; // c[0] | |||||
| Tensor[] captures = new Tensor[_captures.Count]; | |||||
| ICollection inner = _captures.Keys; // c[0] | |||||
| inner.CopyTo(captures, 0); | inner.CopyTo(captures, 0); | ||||
| return captures; | return captures; | ||||
| } | } | ||||
| public Tensor[] internal_captures() | public Tensor[] internal_captures() | ||||
| { | { | ||||
| Tensor[] captures = new Tensor[this._captures.Count]; | |||||
| ICollection inner = this._captures.Values; // c[1] | |||||
| Tensor[] captures = new Tensor[_captures.Count]; | |||||
| ICollection inner = _captures.Values; // c[1] | |||||
| inner.CopyTo(captures, 0); | inner.CopyTo(captures, 0); | ||||
| return captures; | return captures; | ||||
| } | } | ||||
| @@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | ||||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | var outputs = node.Layer.Apply(layer_inputs, is_training: training); | ||||
| foreach (var output in outputs.Where(x => x != null)) | foreach (var output in outputs.Where(x => x != null)) | ||||
| tf.Logger.Debug($"{output.TensorShape}"); | |||||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||||
| // Update tensor_dict for next input | // Update tensor_dict for next input | ||||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | ||||
| tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | ||||