| @@ -35,6 +35,7 @@ namespace Tensorflow.Contexts | |||
| public string ScopeName { get; set; } = ""; | |||
| bool initialized = false; | |||
| ContextSwitchStack context_switches; | |||
| public FunctionCallOptions FunctionCallOptions { get; } | |||
| public SafeContextHandle Handle { get; } | |||
| @@ -44,6 +45,7 @@ namespace Tensorflow.Contexts | |||
| status.Check(true); | |||
| context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | |||
| initialized = true; | |||
| FunctionCallOptions = new FunctionCallOptions(); | |||
| } | |||
| /// <summary> | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.Collections; | |||
| namespace Tensorflow.Contexts | |||
| { | |||
| public class FunctionCallOptions | |||
| { | |||
| public string config_proto_serialized() | |||
| { | |||
| var config = new ConfigProto | |||
| { | |||
| AllowSoftPlacement = true, | |||
| }; | |||
| return config.ToByteString().ToStringUtf8(); | |||
| } | |||
| } | |||
| } | |||
| @@ -371,7 +371,7 @@ namespace Tensorflow.Eager | |||
| switch (type) | |||
| { | |||
| case TF_AttrType.TF_ATTR_STRING: | |||
| c_api.TFE_OpSetAttrString(op, key, value.ToString(), (uint)value.ToString().Length); | |||
| c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_TYPE: | |||
| c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); | |||
| @@ -241,7 +241,7 @@ namespace Tensorflow | |||
| /// <param name="value">const void*</param> | |||
| /// <param name="length">size_t</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||
| @@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| IntPtr _handle; | |||
| FuncGraph func_graph; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures(); | |||
| public string Name | |||
| { | |||
| @@ -38,6 +39,8 @@ namespace Tensorflow.Functions | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||
| { | |||
| func_graph = graph; | |||
| ToGraph(graph.Inputs, graph.Outputs); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
| @@ -124,6 +127,21 @@ namespace Tensorflow.Functions | |||
| return flat_outputs; | |||
| } | |||
| public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs) | |||
| { | |||
| var new_args = new List<Tensor>(); | |||
| new_args.AddRange(args); | |||
| new_args.AddRange(captured_inputs); | |||
| args = new_args.ToArray(); | |||
| var attrs = new object[] | |||
| { | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| @@ -48,7 +48,7 @@ namespace Tensorflow.Functions | |||
| getBackwardFunction: () => backward_function); | |||
| } | |||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) | |||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
| { | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| { | |||
| @@ -61,10 +61,11 @@ namespace Tensorflow.Functions | |||
| processed_args.add(arg); | |||
| input_index += 1; | |||
| } | |||
| return output_grads;// backward.Invoke(processed_args.ToArray()); | |||
| return backward.CallFlat(processed_args.ToArray(), outputs); | |||
| }; | |||
| return (_backward_function_wrapper, flat_outputs); | |||
| return (_backward_function_wrapper, outputs); | |||
| } | |||
| protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
| @@ -82,7 +83,7 @@ namespace Tensorflow.Functions | |||
| } | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||
| foreach (var output in trainable_outputs) | |||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
| var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
| @@ -90,16 +91,15 @@ namespace Tensorflow.Functions | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| tf.Context.restore_mode(); | |||
| var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures()); | |||
| backwards_graph.Inputs = gradients_wrt_outputs; | |||
| backwards_graph.Outputs = gradients_wrt_inputs; | |||
| var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
| var forward_function_attr = new Dictionary<string, string>(); | |||
| forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
| var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
| @@ -49,14 +49,14 @@ namespace Tensorflow | |||
| RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | |||
| (oper, out_grads) => | |||
| { | |||
| tf.Logger.Debug($"Caculate Gradient: {m.Name}"); | |||
| tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | |||
| var results = g.InvokeMember(m.Name, | |||
| BindingFlags.InvokeMethod, | |||
| null, | |||
| null, | |||
| args: new object[] { oper, out_grads }) as Tensor[]; | |||
| foreach (var result in results.Where(x => x != null)) | |||
| tf.Logger.Debug($"{result.TensorShape}"); | |||
| tf.Logger.Debug($"Gradient: {result.name} {result.TensorShape}"); | |||
| return results; | |||
| } | |||
| ); | |||
| @@ -26,7 +26,9 @@ namespace Tensorflow.Graphs | |||
| public Tensors Outputs { get; set; } | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
| // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
| // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||
| /// <summary> | |||
| /// Construct a new FuncGraph. | |||
| /// </summary> | |||
| @@ -129,7 +131,7 @@ namespace Tensorflow.Graphs | |||
| Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | |||
| { | |||
| Tensor placeholder = null; | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| if (!_captures.Contains(tensor.Id)) | |||
| { | |||
| placeholder = _create_substitute_placeholder(tensor, | |||
| name: name, | |||
| @@ -139,7 +141,7 @@ namespace Tensorflow.Graphs | |||
| } | |||
| else | |||
| { | |||
| placeholder = _captures[tensor.Id].Item1; | |||
| placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; | |||
| } | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| @@ -557,16 +557,16 @@ namespace Tensorflow | |||
| public Tensor[] external_captures() | |||
| { | |||
| Tensor[] captures = new Tensor[this._captures.Count]; | |||
| ICollection inner = this._captures.Keys; // c[0] | |||
| Tensor[] captures = new Tensor[_captures.Count]; | |||
| ICollection inner = _captures.Keys; // c[0] | |||
| inner.CopyTo(captures, 0); | |||
| return captures; | |||
| } | |||
| public Tensor[] internal_captures() | |||
| { | |||
| Tensor[] captures = new Tensor[this._captures.Count]; | |||
| ICollection inner = this._captures.Values; // c[1] | |||
| Tensor[] captures = new Tensor[_captures.Count]; | |||
| ICollection inner = _captures.Values; // c[1] | |||
| inner.CopyTo(captures, 0); | |||
| return captures; | |||
| } | |||
| @@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine | |||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | |||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||
| foreach (var output in outputs.Where(x => x != null)) | |||
| tf.Logger.Debug($"{output.TensorShape}"); | |||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||
| // Update tensor_dict for next input | |||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||
| tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | |||