diff --git a/Tensorflow.Common/Extensions/DictionaryExtension.cs b/Tensorflow.Common/Extensions/DictionaryExtension.cs new file mode 100644 index 00000000..7502a3a7 --- /dev/null +++ b/Tensorflow.Common/Extensions/DictionaryExtension.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class DictionaryExtension + { + public static void Deconstruct(this KeyValuePair pair, out T1 first, out T2 second) + { + first = pair.Key; + second = pair.Value; + } + public static void Update(this Dictionary dic, IDictionary other) + { + foreach(var (key, value) in other) + { + dic[key] = value; + } + } + public static T2 GetOrDefault(this Dictionary dic, T1 key, T2 defaultValue) + { + if (dic.ContainsKey(key)) + { + return dic[key]; + } + return defaultValue; + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs index d722cb14..492b1034 100644 --- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -21,7 +21,7 @@ namespace Tensorflow { public partial class tensorflow { - GradientTape _tapeSet; + internal GradientTape _tapeSet; /// /// Record operations for automatic differentiation. diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index 91293b3a..35efde06 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Operations; + namespace Tensorflow { public partial class tensorflow @@ -79,5 +81,10 @@ namespace Tensorflow num_split: num_split, axis: axis, name: name); + + public Tensor ensure_shape(Tensor x, Shape shape, string name = null) + { + return gen_ops.ensure_shape(x, shape, name); + } } } diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs index 7d9ff65f..2a22413b 100644 --- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -61,7 +61,7 @@ namespace Tensorflow public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); /// /// Set `num_dims` to -1 to represent "unknown rank". diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 5d9d799d..8df39334 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -22,6 +22,7 @@ using System.ComponentModel; using System.Diagnostics; using System.IO; using System.Linq; +using Tensorflow.Operations; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 9ec9e22f..330e30ca 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -107,6 +107,12 @@ namespace Tensorflow } } + public void Release() + { + _handle.Dispose(); + _handle = null; + } + public override string ToString() => $"0x{_handle.DangerousGetHandle():x16}"; diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 9793798d..490c284b 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -161,7 +161,7 @@ public static class CheckPointUtils internal static IEnumerable _objects_with_attributes(IEnumerable full_list) { - return full_list.TakeWhile(x => + return full_list.Where(x => { var saveables = x.gather_saveables_for_checkpoint(); return saveables is not null && saveables.Count > 0; diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index 4aa2a808..7a5da7e3 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -109,6 +109,7 @@ namespace Tensorflow.Checkpoint TrackableObjectGraph.Types.TrackableObject trackable_object = new(); trackable_object.SlotVariables.AddRange(td.slot_variable_proto); trackable_object.Children.AddRange(td.children_proto); + object_graph_proto.Nodes.Add(trackable_object); } return object_graph_proto; } diff --git a/src/TensorFlowNET.Core/Contexts/Context.Config.cs b/src/TensorFlowNET.Core/Contexts/Context.Config.cs index b363b516..0c7bded6 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.Config.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.Config.cs @@ -14,9 +14,11 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Diagnostics; using System.Linq; +using Tensorflow.Common.Extensions; namespace Tensorflow.Contexts { @@ -25,12 +27,93 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { - public ConfigProto Config { get; set; } = new ConfigProto + protected Device.PhysicalDevice[] _physical_devices; + protected Dictionary _physical_device_to_index; + ConfigProto _config; + public ConfigProto Config { - GpuOptions = new GPUOptions + get { + _initialize_physical_devices(); + + var config = new ConfigProto(); + if(_config is not null) + { + config.MergeFrom(_config); + } + config.LogDevicePlacement = _log_device_placement; + + config.DeviceCount["CPU"] = 0; + config.DeviceCount["GPU"] = 0; + foreach(var dev in _physical_devices) + { + if (config.DeviceCount.ContainsKey(dev.DeviceType)) + { + config.DeviceCount[dev.DeviceType] += 1; + } + else + { + config.DeviceCount[dev.DeviceType] = 1; + } + } + + var gpu_options = _compute_gpu_options(); + config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); + + return config; + } + set + { + _config = value; + } + } + + protected void _initialize_physical_devices(bool reinitialize = false) + { + if(!reinitialize && _physical_devices is not null) + { + return; + } + var devs = list_physical_devices(); + _physical_devices = devs.Select(d => new Device.PhysicalDevice() + { + DeviceName = d.DeviceName, + DeviceType = d.DeviceType + }).ToArray(); + _physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair(p, i)) + .ToDictionary(x => x.Key, x => x.Value); + + _import_config(); + } + + protected void _import_config() + { + if(_config is null) + { + return; + } + if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) + { + num_cpus = 1; + } + if(num_cpus != 1) + { + // TODO(Rinne): implement it. } - }; + + var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); + if(gpus.Count() == 0) + { + return; + } + + if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) + { + gpu_count = 0; + } + + // TODO(Rinne): implement it. + } ConfigProto MergeConfig() { diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index deb67920..7fec1e5a 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -38,7 +38,26 @@ namespace Tensorflow.Contexts public string ScopeName { get; set; } = ""; bool initialized = false; ContextSwitchStack context_switches; - public FunctionCallOptions FunctionCallOptions { get; } + protected FunctionCallOptions _function_call_options; + public FunctionCallOptions FunctionCallOptions + { + get + { + if(_function_call_options is null) + { + var config = Config; + _function_call_options = new FunctionCallOptions() + { + Config = config + }; + } + return _function_call_options; + } + set + { + _function_call_options = value; + } + } SafeContextHandle _handle; @@ -62,7 +81,6 @@ namespace Tensorflow.Contexts if (initialized) return; - Config = MergeConfig(); FunctionCallOptions.Config = Config; var config_str = Config.ToByteArray(); var opts = new ContextOptions(); @@ -167,11 +185,29 @@ namespace Tensorflow.Contexts return c_api.TFE_ContextHasFunction(_handle, name); } + public void add_function(SafeFuncGraphHandle fn) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextAddFunction(_handle, fn, status); + status.Check(true); + } + + public void remove_function(string name) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextRemoveFunction(_handle, name, status); + status.Check(true); + } + public void add_function_def(FunctionDef fdef) { ensure_initialized(); - var fdef_string = fdef.ToString(); - c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); + var fdef_string = fdef.ToByteArray(); + Status status = new Status(); + c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); + status.Check(true); } public void restore_mode() diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs index 6b6028f0..2fcf9dce 100644 --- a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs +++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs @@ -9,10 +9,11 @@ namespace Tensorflow.Contexts public class FunctionCallOptions { public ConfigProto Config { get; set; } + public string ExecutorType { get; set; } - public string config_proto_serialized() + public ByteString config_proto_serialized() { - return Config.ToByteString().ToStringUtf8(); + return Config.ToByteString(); } } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index aa205d45..3806b3ad 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow.Eager diff --git a/src/TensorFlowNET.Core/Eager/backprop_util.cs b/src/TensorFlowNET.Core/Eager/backprop_util.cs new file mode 100644 index 00000000..0d726e1d --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/backprop_util.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; + +namespace Tensorflow.Eager +{ + internal static class backprop_util + { + // TODO: add quantized_dtypes (after being supported). + private static HashSet _trainable_dtypes = new HashSet(new TF_DataType[] + { + dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, + dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 + }); + public static bool IsTrainable(Tensor tensor) + { + var dtype = _DTypeFromTensor(tensor); + return _trainable_dtypes.Contains(dtype); + } + public static bool IsTrainable(TF_DataType dtype) + { + return _trainable_dtypes.Contains(dtype); + } + + private static TF_DataType _DTypeFromTensor(Tensor tensor) + { + var dtype = tensor.dtype; + if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) + { + CppShapeInferenceResult.Types.HandleData handle_data; + if (tensor is EagerTensor) + { + handle_data = tensor.HandleData; + } + else + { + handle_data = handle_data_util.get_resource_handle_data(tensor); + } + if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && + handle_data.ShapeAndType.Count > 0) + { + var first_type = handle_data.ShapeAndType[0].Dtype; + if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) + { + return first_type.as_tf_dtype(); + } + } + } + return dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index e8746c1b..665e537f 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -31,7 +31,7 @@ namespace Tensorflow public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); + public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); @@ -280,7 +280,7 @@ namespace Tensorflow public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); + public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); /// /// diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs deleted file mode 100644 index bce889b6..00000000 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Tensorflow.Framework.Models -{ - class ScopedTFFunction - { - } -} diff --git a/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs new file mode 100644 index 00000000..11e920f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + internal class ScopedTFFunction + { + SafeFuncGraphHandle _handle; + string _name; + public ScopedTFFunction(SafeFuncGraphHandle func, string name) + { + _handle = func; + _name = name; + } + + public SafeFuncGraphHandle Get() + { + return _handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs index b81cb71b..67f8d324 100644 --- a/src/TensorFlowNET.Core/Framework/function_def_lib.cs +++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Security.Cryptography; using System.Text; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; using static Tensorflow.CppShapeInferenceResult.Types; @@ -64,7 +65,7 @@ namespace Tensorflow.Framework { output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; } - // TODO(Rinne): func_graph._output_names = output_names + func_graph._output_names = output_names; func_graph.Exit(); return func_graph; @@ -154,9 +155,17 @@ namespace Tensorflow.Framework foreach(var node_def in fdef.NodeDef) { var graph = default_graph; - // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. - while(graph.OuterGraph is not null) + while (true) { + if(graph is null) + { + break; + } + var f = graph.Functions.GetOrDefault(node_def.Op, null); + if(f is not null && graph.OuterGraph is null) + { + break; + } graph = graph.OuterGraph; } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 402d876e..8524f724 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Linq; using Tensorflow.Eager; using Tensorflow.Framework.Models; +using Tensorflow.Gradients; using Tensorflow.Graphs; using Tensorflow.Train; using Tensorflow.Util; @@ -19,7 +20,7 @@ namespace Tensorflow.Functions protected IEnumerable _captured_inputs; internal FuncGraph func_graph; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; - protected Dictionary _attrs; + protected Dictionary _attrs; protected FunctionSpec _function_spec; protected FunctionSpec _pre_initialized_function_spec = null; protected EagerDefinedFunction _inference_function; @@ -29,22 +30,25 @@ namespace Tensorflow.Functions public string Name => _delayed_rewrite_functions.Forward().Name; - public Tensor[] Outputs; + public Tensor[] Outputs => func_graph.Outputs; public Type ReturnType; public TensorSpec[] OutputStructure; public IEnumerable ArgKeywords { get; set; } public long NumPositionArgs { get; set; } public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; + public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; + public IEnumerable Variables => func_graph.Variables; + public IEnumerable TrainableVariables => func_graph.TrainableVariables; public ConcreteFunction(string name) { func_graph = new FuncGraph(name); _captured_inputs = func_graph.external_captures; - _attrs= new Dictionary(); + _attrs= new Dictionary(); _set_infer_function(); } - public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) + public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) { func_graph = graph; _captured_inputs = func_graph.external_captures; @@ -70,7 +74,7 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; - _attrs = new Dictionary(); + _attrs = new Dictionary(); _set_infer_function(); } @@ -93,7 +97,7 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; - _attrs = new Dictionary(); + _attrs = new Dictionary(); _set_infer_function(); } @@ -160,27 +164,20 @@ namespace Tensorflow.Functions } if (!executing_eagerly) { - + // TODO(Rinne): add the check } - tensor_inputs.AddRange(captured_inputs); + tensor_inputs.AddRange(captured_inputs); args = tensor_inputs.ToArray(); - var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; + var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); // No tape is watching; skip to running the function. - if (possible_gradient_type == 0 && executing_eagerly) + if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) { return _build_call_outputs(_inference_function.Call(args)); - //var attrs = new object[] - //{ - // "executor_type", "", - // "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() - //}; - //return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } - if (forward_backward == null) - forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); + forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); var (forward_function, args_with_tangents) = forward_backward.Forward(); Tensors flat_outputs = null; if (executing_eagerly) @@ -189,8 +186,12 @@ namespace Tensorflow.Functions } else { - // TODO(Rinne): add `default_graph._override_gradient_function`. - flat_outputs = forward_function.Call(args_with_tangents); + tf_with(default_graph._override_gradient_function(new Dictionary>(){ + { "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } + }), _ => + { + flat_outputs = forward_function.Call(args_with_tangents); + }); } forward_backward.Record(flat_outputs); return _build_call_outputs(flat_outputs); @@ -215,7 +216,8 @@ namespace Tensorflow.Functions TangentInfo input_tangents; if (executing_eagerly) { - throw new NotImplementedException(); + // TODO(Rinne): check if it needs to be implemented. + input_tangents = new TangentInfo(); } else { @@ -239,7 +241,12 @@ namespace Tensorflow.Functions } // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. - return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient()); + } + + internal void set_variables(IEnumerable variables) + { + func_graph.Variables = variables; } internal void _set_infer_function() @@ -274,6 +281,11 @@ namespace Tensorflow.Functions }; } + internal Func _get_gradient_function() + { + return _delayed_rewrite_functions._rewrite_forward_and_call_backward; + } + private Tensors _build_call_outputs(Tensors result) { // TODO(Rinne): dwal with `func_graph.structured_outputs` diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index 61d3121c..c2f8e016 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -9,18 +9,27 @@ using Tensorflow.Eager; using Tensorflow.Graphs; using Tensorflow.Operations; using Tensorflow.Util; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; +using Tensorflow.Framework; +using System.Buffers; +using Tensorflow.Gradients; namespace Tensorflow.Functions { - public class EagerDefinedFunction + public class EagerDefinedFunction: IDisposable { public int _num_outputs; - FuncGraph _func_graph; + FuncGraph _graph; FunctionDef _definition; OpDef _signature; string _name; - Tensor[] _func_graph_outputs; + internal ScopedTFFunction _c_func; + internal Tensor[] _func_graph_outputs; + internal string _grad_func_name; + internal Func csharp_grad_func; + internal EagerDefinedFunction _grad_func; + internal bool _registered_on_context = false; public string Name => _name; public DataType[] OutputTypes { get; protected set; } public Shape[] OutputShapes { get; protected set; } @@ -47,48 +56,93 @@ namespace Tensorflow.Functions return _signature; } } - public EagerDefinedFunction(string name, FuncGraph graph, + public unsafe EagerDefinedFunction(string name, FuncGraph graph, Tensors inputs, Tensors outputs, - Dictionary attrs) + Dictionary attrs) { var input_ops = inputs.Select(x => x.op).ToArray(); var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) .Select(x => x as Operation).ToArray(); - var output_names = new string[0]; - _func_graph = new FuncGraph(graph, name, attrs); - _func_graph_outputs = new List(outputs).ToArray(); - _func_graph.ToGraph(operations, inputs, outputs, output_names); + var graph_output_names = graph._output_names; + string[] output_names; + if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) + { + output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); + if(output_names.Distinct().Count() != output_names.Length) + { + output_names = new string[0]; + } + } + else + { + output_names = new string[0]; + } + + Status status = new Status(); + var fn = c_api.TF_GraphToFunction(graph.c_graph, + name, + false, + operations.Length, + operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), + inputs.Length, + inputs.Select(t => t._as_tf_output()).ToArray(), + outputs.Length, + outputs.Select(t => t._as_tf_output()).ToArray(), + output_names.Length != outputs.Length ? null : output_names, + IntPtr.Zero, // warning: the control output hasbben totally ignored. + null, + status); + status.Check(true); + + _c_func = new ScopedTFFunction(fn, name); + + foreach(var (attr_name, attr_value) in attrs) + { + var serialized = attr_value.ToByteArray(); + c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); + status.Check(true); + } var signature = _get_definition().Signature; _name = signature.Name; - // TODO(Rinne): deal with `fn` + tf_with(ops.init_scope(), s => + { + tf.Context.add_function(fn); + _registered_on_context = true; + }); _num_outputs = signature.OutputArg.Count; OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); OutputShapes = outputs.Select(x => x.shape).ToArray(); + _func_graph_outputs = new List(outputs).ToArray(); + csharp_grad_func = null; + _graph = graph; } - public Tensors Call(Tensors args) + public unsafe Tensors Call(Tensors args) { // TODO(Rinne): Add arg `CancellationManager`. // TODO(Rinne): Check the arg length. var function_call_options = tf.Context.FunctionCallOptions; string config; - if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) + if (function_call_options.config_proto_serialized().Length == 0) { - config = function_utils.get_disabled_rewriter_config(); + config = function_utils.get_disabled_rewriter_config().ToString(); } else { - config = function_call_options.config_proto_serialized(); + config = function_call_options.config_proto_serialized().ToString(); } - // TODO(Rinne): executor_type + + config = ""; // TODO(Rinne): revise it. + + string executor_type = function_call_options.ExecutorType ?? ""; var executing_eagerly = tf.Context.executing_eagerly(); var attrs = new object[] { - "executor_type", "", - "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + "executor_type", executor_type, + "config_proto", config }; Tensor[] outputs; @@ -103,9 +157,19 @@ namespace Tensorflow.Functions } else { - tf.GradientTape().stop_recording(); - outputs = functional_ops.partitioned_call(args, this, OutputTypes, - executing_eagerly, config, ""); + if(tf.GetTapeSet().Count == 0) + { + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + } + else + { + var tape = tf.GetTapeSet().Peek(); + tape.StopRecord(); + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + tape.StartRecord(); + } } foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) { @@ -141,7 +205,7 @@ namespace Tensorflow.Functions { g.AddFunction(this); } - foreach(var f in _func_graph.Functions.Values) + foreach(var f in _graph.Functions.Values) { if (!g.IsFunction(f.Name)) { @@ -155,12 +219,15 @@ namespace Tensorflow.Functions { var buffer = c_api_util.tf_buffer(); Status status = new(); - c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status); + c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); status.Check(true); var proto_data = c_api.TF_GetBuffer(buffer); - FunctionDef function_def = new(); - function_def.MergeFrom(proto_data.AsSpan()); - return function_def; + return FunctionDef.Parser.ParseFrom(proto_data.AsSpan()); + } + + public void Dispose() + { + tf.Context.remove_function(Name); } } } diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs index 3c099927..c0e69dba 100644 --- a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs @@ -17,9 +17,9 @@ namespace Tensorflow.Functions public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) { var outputs = _func_graph.Outputs; - (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) + (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs) = BuildFunctionsForOutputs(outputs, inference_args); - return _forward; + return _forward_function; } } } diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index cfea3954..a53df14c 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -10,23 +10,26 @@ namespace Tensorflow private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - protected Func _function; + protected Func _csharp_function; protected ConcreteFunction _concrete_variable_creation_fn; - protected bool _auto_graph; + protected bool _autograph; + protected TracingCompiler _variable_creation_fn; + protected bool _has_initialized; public string Name { get; set; } - public Function(Func function, + public Function(Func csharp_function, string name, bool auto_graph = true) { - _function = function; + _csharp_function = csharp_function; Name = name; - _auto_graph = auto_graph; + _autograph = auto_graph; + _has_initialized = false; } public virtual Tensors Apply(Tensors inputs) { if (_run_functions_eagerly()) { - return _function(inputs); + return _csharp_function(inputs); } var result = _call(inputs); @@ -35,20 +38,32 @@ namespace Tensorflow protected virtual Tensors _call(Tensors inputs) { - _initialize(); + if (!_has_initialized) + { + _initialize(inputs); + } return _concrete_variable_creation_fn.CallFlat(inputs, _concrete_variable_creation_fn.CapturedInputs); } + protected TracingCompiler _compiler(Func fn) + { + var name = nameof(fn); + return new TracingCompiler(fn, name, autograph: _autograph); + } + protected virtual bool _run_functions_eagerly() { return false; } - private void _initialize() + private void _initialize(Tensor[] args) { - + _variable_creation_fn = _compiler(_csharp_function); + _variable_creation_fn._name = this.Name; + _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); + _has_initialized = true; } } } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 23889d44..638aeaf1 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -3,8 +3,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Eager; +using Tensorflow.Gradients; using Tensorflow.Graphs; using Tensorflow.NumPy; +using Tensorflow.Operations; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -22,11 +24,11 @@ namespace Tensorflow.Functions protected string _INFERENCE_PREFIX = "__inference_"; protected FuncGraph _func_graph; - protected EagerDefinedFunction _forward; + protected EagerDefinedFunction _forward_function; protected FuncGraph _forward_graph; protected List _forwardprop_output_indices; protected int _num_forwardprop_outputs; - protected ConcreteFunction _backward; + protected ConcreteFunction _backward_function; BackwardFunction _backward_function_wrapper; public TapeGradientFunctions(FuncGraph func_graph, @@ -49,8 +51,8 @@ namespace Tensorflow.Functions public virtual void Record(Tensors flat_outputs, Tensors inference_args) { // TODO(Rinne): add arg `input_tagents`. - var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); - tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, + var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs); + tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record, getBackwardFunction: backward_function); } @@ -134,46 +136,58 @@ namespace Tensorflow.Functions var trainable_indices = new List(); foreach(var (index, output) in enumerate(outputs)) { - if (gradients_util.IsTrainable(output)) + if (backprop_util.IsTrainable(output)) { trainable_outputs.Add(output); trainable_indices.Add(index); } } - var gradients_wrt_outputs = new List(); - var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); + var backwards_graph = new FuncGraph(_func_graph.Name); backwards_graph.as_default(); + var gradients_wrt_outputs = new List(); foreach (var output in trainable_outputs) - gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); + { + var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); + var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); + gradients_wrt_outputs.Add(gradient_placeholder); + handle_data_util.copy_handle_data(output, gradient_placeholder); + } var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), - _func_graph.Inputs, - grad_ys: gradients_wrt_outputs.ToArray(), - src_graph: _func_graph); + _func_graph.Inputs, + grad_ys: gradients_wrt_outputs.ToArray(), + src_graph: _func_graph); var captures_from_forward = backwards_graph.external_captures .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) .ToArray(); + HashSet existing_outputs = new(_func_graph.Outputs); foreach(var capture in captures_from_forward) { - if (!_func_graph.Outputs.Contains(capture)) + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); _func_graph.Outputs.Add(capture); + } } backwards_graph.Exit(); - var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; - var backward_function_attr = new Dictionary(); - backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; - gradients_wrt_outputs.append(backwards_graph.internal_captures); - backwards_graph.Inputs = gradients_wrt_outputs; - backwards_graph.Outputs = gradients_wrt_inputs; + backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); + backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); + + var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); + //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; + //var backward_function_attr = new Dictionary(); + //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; - var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + //var backward_function = new ConcreteFunction(backwards_graph, + // monomorphic_function_utils._parse_func_attrs(backward_function_attr)); - var forward_function_attr = new Dictionary(); - forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; - var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, - _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); + //var forward_function_attr = new Dictionary(); + //forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; + //var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, + // _func_graph.Inputs, _func_graph.Outputs, + // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); return (forward_function, _func_graph, backward_function, null, 0); } diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs new file mode 100644 index 00000000..8a844671 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class TracingCompiler + { + Func _csharp_function; + //FunctionSpec _function_spec; + internal string _name; + bool _autograph; + Dictionary _function_cache; + Dictionary _function_attributes; + int _tracing_count; + + public TracingCompiler(Func csharp_function, string name, object? input_signatures = null, + Dictionary attributes = null, bool autograph = true, object? autograph_options = null, + bool reduce_retracing = false, bool capture_by_value = false) + { + _csharp_function = csharp_function; + bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); + _name = name; + _autograph = autograph; + _function_attributes = attributes ?? new Dictionary(); + _function_cache = new Dictionary(); + _tracing_count = 0; + } + + public Tensor[] Apply(Tensor[] inputs) + { + // TODO(Rinne): add lock here. + var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); + return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); + } + + internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) + { + var (concrete_function, _) = _maybe_define_concrete_function(args); + return concrete_function; + } + + private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) + { + return _maybe_define_function(args); + } + + private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) + { + var lookup_func_key = male_cache_key(args); + if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) + { + return (concrete_function, args); + } + concrete_function = _create_concrete_function(args); + _function_cache[lookup_func_key] = concrete_function; + return (concrete_function, args); + } + + private ConcreteFunction _create_concrete_function(Tensor[] args) + { + _tracing_count++; + + int arglen = args.Length; + var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( + _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), + args, new Dictionary(), autograph: _autograph + ), _function_attributes); + return concrete_function; + } + + private static string male_cache_key(Tensor[] inputs) + { + string res = ""; + foreach (var input in inputs) + { + res += $"{input.name}_{input.Id}"; + } + return res; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index 3fbb3868..04d102b5 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -16,6 +16,7 @@ using System; using System.Runtime.InteropServices; +using Tensorflow.Functions; namespace Tensorflow { @@ -54,6 +55,9 @@ namespace Tensorflow public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); [DllImport(TensorFlowLibName)] - public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); + public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs new file mode 100644 index 00000000..7994bef1 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Util; + +namespace Tensorflow.Functions +{ + internal static class composite_tensor_utils + { + public static List flatten_with_variables(object inputs) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(inputs)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + public static List flatten_with_variables_or_variable_specs(object arg) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(arg)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + // TODO(Rinne): deal with `VariableSpec`. + else if(value is TypeSpec type_spec && value is not TensorSpec) + { + throw new NotImplementedException("The TypeSpec has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs index e92fa3a1..b3caef96 100644 --- a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs +++ b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs @@ -34,11 +34,10 @@ namespace Tensorflow.Functions "https://github.com/SciSharp/TensorFlow.NET/issues"); } }); - var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); + var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); List captured_inputs_list = new(); - // TODO(Rinne): concrete_function.set_variables(bound_variables) - + concrete_function.set_variables(bound_variables); if (bound_inputs is not null) { @@ -54,8 +53,14 @@ namespace Tensorflow.Functions concrete_function.func_graph.replace_capture(bound_input, internal_capture); if(internal_capture.dtype == dtypes.resource) { - // skip the check of variable. - handle_data_util.copy_handle_data(bound_input, internal_capture); + if (resource_variable_ops.is_resource_variable(bound_input)) + { + handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); + } + else + { + handle_data_util.copy_handle_data(bound_input, internal_capture); + } } concrete_function.func_graph.capture(bound_input); } diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs index a8769438..acf00597 100644 --- a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -1,20 +1,137 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; +using Tensorflow.Framework.Models; +using Tensorflow.Gradients; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; +using Tensorflow.Operations; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using System.Diagnostics; namespace Tensorflow.Functions { - public class DelayedRewriteGradientFunctions: TapeGradientFunctions + internal static class monomorphic_function_utils + { + internal static string _FORWARD_PREFIX = "__forward_"; + internal static string _BACKWARD_PREFIX = "__backward_"; + internal static string _INFERENCE_PREFIX = "__inference_"; + internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; + internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + public static string _inference_name(string name) + { + return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + } + public static string _forward_name(string name) + { + return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; + } + public static string _backward_name(string name) + { + return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; + } + + public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary attrs, + FuncGraph forward_graph, FuncGraph backwards_graph) + { + string forward_function_name = _forward_name(forward_graph.Name); + Dictionary common_attributes; + if(attrs is null) + { + common_attributes = new Dictionary(); + } + else + { + common_attributes = new Dictionary(attrs); + } + + if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) + { + common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); + } + var backward_function_attr = _parse_func_attrs(new Dictionary() + { + {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } + }); + backward_function_attr.Update(common_attributes); + var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + var forward_function_attr = _parse_func_attrs(new Dictionary() + { + {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } + }); + forward_function_attr.Update(common_attributes); + var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, + forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); + return (forward_function, backward_function); + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach(var item in attributes) + { + var key = item.Key; + var value = item.Value; + if (value is AttrValue attr_value) + { + attrs[key] = attr_value; + } + else if (value is bool b) + { + attrs[key] = new AttrValue() { B = b }; + } + else if (value is int i) + { + attrs[key] = new AttrValue() { I = i }; + } + else if (value is float f) + { + attrs[key] = new AttrValue() { F = f }; + } + else if(value is string s) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; + } + else if (value is byte[] bytes) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; + } + else + { + throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + + $"AttrValue. Got {value.GetType()}."); + } + } + return attrs; + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach (var item in attributes) + { + var key = item.Key; + var value = item.Value; + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; + } + return attrs; + } + } + public class DelayedRewriteGradientFunctions : TapeGradientFunctions { EagerDefinedFunction _inference_function; - Dictionary _attrs; + Dictionary _attrs; int _num_inference_outputs; - public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) - :base(func_graph, false) + Dictionary _cached_function_pairs = new(); + public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) + : base(func_graph, false) { - _func_graph= func_graph; - _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), + _func_graph = func_graph; + _inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); _attrs = attrs; _num_inference_outputs = _func_graph.Outputs.Length; @@ -22,7 +139,7 @@ namespace Tensorflow.Functions public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) { - if(input_tangents is not null) + if (input_tangents is not null) { throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + $"a class that does not support forwardprop."); @@ -32,23 +149,134 @@ namespace Tensorflow.Functions public override void Record(Tensors flat_outputs, Tensors inference_args) { - // TODO(Rinne): implement it. - throw new NotImplementedException(); - base.Record(flat_outputs, inference_args); + var (backward_function, to_record) = _backward(flat_outputs); + foreach(var tape in tf.GetTapeSet()) + { + tape.RecordOperation(_inference_function.Signature.Name, to_record, + inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function); + } } - //private (BackwardFunction, Tensors) _backward(Tensors outputs) - //{ - // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) - // { - // var call_op = outputs[0].op; + public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) + { + if(num_doutputs == -2) + { + num_doutputs = _num_inference_outputs; + } + if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) + { + return target; + } + var (forward, backward) = _construct_forward_backward(num_doutputs); + _cached_function_pairs[num_doutputs] = (forward, backward); + return (forward, backward); - // } - //} + } - private string _inference_name(string name) + private (BackwardFunction, Tensors) _backward(Tensors outputs) { - return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) + { + var call_op = outputs[0].op; + return _rewrite_forward_and_call_backward(call_op, args); + } + return (backward_function, outputs); + } + + internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) + { + var (forward_function, backward_function) = forward_backward(doutputs.Length); + if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) + { + return backward_function.FlatStructuredOutputs; + } + forward_function.AddToGraph(op.graph); + + op._set_func_attr("f", forward_function.Name); + op._set_type_list_attr("Tout", forward_function.OutputTypes); + op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). + Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() + ); + for(int i = 0; i < op.outputs.Length; i++) + { + var func_graph_output = forward_function._func_graph_outputs[i]; + handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); + } + + var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). + ToDictionary(x => x.Item1, x => x.Item2); + var remapped_captures = backward_function.CapturedInputs.Select( + x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) + ); + + List cleaned_doutputs = new(); + foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) + { + if (backprop_util.IsTrainable(placeholder)) + { + if(doutput is IndexedSlices) + { + cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); + } + else if(doutput is null) + { + cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); + } + else if(doutput is Tensor tensor) + { + cleaned_doutputs.Add(tensor); + } + else + { + throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); + } + } + } + + return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); + } + + private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) + { + var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); + + List signature = new(); + foreach(var t in trainable_outputs) + { + var (shape, dtype) = default_gradient.shape_and_dtype(t); + signature.Add(new TensorSpec(shape, dtype)); + } + + Tensor[] _backprop_function(Tensor[] grad_ys) + { + return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, + grad_ys, src_graph: _func_graph); + } + + _func_graph.as_default(); + FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); + FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => + { + Debug.Assert(y is Tensor); + return (Tensor)y; + }).ToArray()), new object[0], new Dictionary(), signature.ToArray(), backwards_graph); + var backwards_graph_captures = backwards_graph.external_captures; + var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); + + HashSet existing_outputs = new HashSet(_func_graph.Outputs); + foreach(var capture in captures_from_forward) + { + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); + _func_graph.Outputs.Add(capture); + } + } + + var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( + _attrs, _func_graph, backwards_graph); + _func_graph.Exit(); + return (forward_function, backward_function); } } } diff --git a/src/TensorFlowNET.Core/Gradients/default_gradient.cs b/src/TensorFlowNET.Core/Gradients/default_gradient.cs new file mode 100644 index 00000000..e6c22e36 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/default_gradient.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + internal static class default_gradient + { + public static (Shape, TF_DataType) shape_and_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); + } + return (t.shape, t.dtype); + } + + public static Tensor zeros_like(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var (shape, dtype) = shape_and_dtype(t); + return array_ops.zeros(shape, dtype); + } + else + { + return array_ops.zeros_like(t); + } + } + + public static TF_DataType get_zeros_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); + } + return t.dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index e6312c0d..10166911 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -14,10 +14,15 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Gradients; using Tensorflow.Graphs; +using Tensorflow.Operations; using Tensorflow.Operations.ControlFlows; using static Tensorflow.Binding; @@ -148,7 +153,7 @@ namespace Tensorflow Tensor[] in_grads = null; Func grad_fn = null; var is_partitioned_call = _IsPartitionedCall(op); - var is_func_call = false; + var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; var has_out_grads = out_grads.Exists(x => x != null); if (has_out_grads && !stop_ops.Contains(op)) { @@ -162,14 +167,41 @@ namespace Tensorflow { if (is_func_call) { + EagerDefinedFunction func_call = null; if (is_partitioned_call) { - + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + var func_name = ((NameAttrList)func_attr).Name; + func_call = src_graph._get_function(func_name); + if(func_call is null && src_graph.OuterGraph is not null) + { + var graph = src_graph.OuterGraph; + while(graph is not null) + { + func_call = graph._get_function(func_name); + if(func_call is not null) + { + break; + } + if(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + else + { + break; + } + } + } } else { - + func_call = src_graph._get_function(op.type); } + // skip the following codes: + // `func_call = getattr(op, "__defun", func_call)` + grad_fn = func_call.csharp_grad_func; } else { @@ -213,6 +245,8 @@ namespace Tensorflow } else { + in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, (x, y) => _SymGrad(x, y)); throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); } _VerifyGeneratedGradients(in_grads, op); @@ -668,6 +702,36 @@ namespace Tensorflow dtypes.resource, dtypes.variant}.Contains(dtype); } + public static int PossibleTapeGradientTypes(Tensor[] tensors) + { + var tape_set = tf.GetTapeSet(); + bool some_tape_watching = false; + if(tape_set is not null && tape_set.Count > 0) + { + foreach(var tape in tape_set) + { + if (tape.ShouldRecord(tensors)) + { + if(tape.Persistent || some_tape_watching) + { + return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; + } + some_tape_watching = true; + } + } + } + // skip the forward_accumulators. + + if (some_tape_watching) + { + return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; + } + else + { + return POSSIBLE_GRADIENT_TYPES_NONE; + } + } + /// /// Return true if op has real gradient. /// @@ -688,7 +752,7 @@ namespace Tensorflow private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) { - scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; + //scope = scope.TrimEnd('/').Replace('/', '_'); return grad_fn(op, out_grads); } @@ -701,5 +765,28 @@ namespace Tensorflow throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + $"inputs {op.inputs._inputs.Count()}"); } + + private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) + { + var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); + var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); + NameAttrList f = new(); + if (_IsPartitionedCall(op)) + { + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + f.Name = ((NameAttrList)func_attr).Name; + } + else + { + f.Name = op.type; + } + foreach(var k in op.node_def.Attr.Keys) + { + f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); + } + var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); + return in_grads; + } } } diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index e5831f25..7d3ea171 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -98,12 +98,23 @@ namespace Tensorflow { if (op.inputs == null) return null; - RegisterFromAssembly(); + var gradient_function = op._gradient_function; + if(gradient_function is null) + { + RegisterFromAssembly(); + + if (!gradientFunctions.ContainsKey(op.type)) + throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); - if (!gradientFunctions.ContainsKey(op.type)) - throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); + return gradientFunctions[op.type]; + } - return gradientFunctions[op.type]; + Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) + { + return gradient_function(operation, args); + } + // TODO(Rinne): check if this needs to be registered. + return wrapped_gradient_function; } } } diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 9367414e..9ef0b95b 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -1,6 +1,15 @@ using Google.Protobuf; +using System; +using System.Buffers; +using System.Diagnostics; +using System.Linq; using Tensorflow.Eager; using Tensorflow.Exceptions; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Functions; +using Tensorflow.Operations; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Graphs; @@ -11,12 +20,65 @@ namespace Tensorflow.Graphs; public class FuncGraph : Graph, IDisposable { internal SafeFuncGraphHandle _func_graph_handle; + internal HashSet _resource_tensor_inputs; + internal HashSet> _watched_variables; + internal IEnumerable> _weak_variables; + internal object[] _structured_outputs; + internal Dictionary _output_names; public string FuncName => _graph_key; public Tensors Inputs { get; set; } = new Tensors(); public Tensors Outputs { get; set; } = new Tensors(); + public Tensors FlatStructuredOutputs + { + get + { + List res = new(); + foreach(var obj in _structured_outputs) + { + if(obj is Tensor tensor) + { + res.Add(tensor); + } + else if(obj is IEnumerable tensors) + { + res.AddRange(tensors); + } + else + { + throw new TypeError("The structured outputs member should be tensor or tensors."); + } + } + return res; + } + } public string Name { get; set; } - public Dictionary Attrs { get; set; } + public IEnumerable Variables + { + get + { + return _weak_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + throw new AssertionError("Called a function referencing variables which have been deleted. " + + "This likely means that function-local variables were created and " + + "not referenced elsewhere in the program. This is generally a " + + "mistake; consider storing variables in an object attribute on first call."); + } + }); + } + internal set + { + _weak_variables = value.Select(x => new WeakReference(x)); + } + } + public IEnumerable TrainableVariables => Variables.Where(v => v.Trainable); + public Dictionary Attrs { get; set; } Dictionary _captures = new Dictionary(); @@ -42,9 +104,12 @@ public class FuncGraph : Graph, IDisposable outer_graph = outer_graph.OuterGraph; _graph_key = Name = name; building_function = true; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); } - public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base() + public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base() { outer_graph = ops.get_default_graph(); while (outer_graph.building_function) @@ -55,6 +120,9 @@ public class FuncGraph : Graph, IDisposable // Will to test if FuncGraph has memory leak // c_api.TF_DeleteGraph(_handle); _handle = handle; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); } public void replace_capture(Tensor tensor, Tensor placeholder) @@ -62,14 +130,14 @@ public class FuncGraph : Graph, IDisposable _captures[tensor.Id] = (tensor, placeholder); } - public void ToGraph(Operation[] opers, + public unsafe void ToGraph(Operation[] opers, Tensor[] inputs, Tensor[] outputs, string[] output_names) { var status = new Status(); - if (output_names != null && output_names.Length == 0) + if (output_names is null) { - output_names = null; + output_names = new string[0]; }; _func_graph_handle = c_api.TF_GraphToFunction(_handle, @@ -81,7 +149,7 @@ public class FuncGraph : Graph, IDisposable inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), outputs.Length, outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), - output_names, + output_names.Length != outputs.Length ? null : output_names, IntPtr.Zero, null, status); @@ -211,6 +279,19 @@ public class FuncGraph : Graph, IDisposable Inputs.Add(placeholder); } + Tensor pop_capture(Tensor tensor) + { + if(_captures.TryGetValue(tensor.Id, out var capture)) + { + _captures.Remove(tensor.Id); + return capture.Item2; + } + else + { + return null; + } + } + Tensor _create_substitute_placeholder(Tensor value, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -234,10 +315,7 @@ public class FuncGraph : Graph, IDisposable foreach (var (_name, attr_value) in enumerate(Attrs)) { - var serialized = new AttrValue - { - S = ByteString.CopyFromUtf8(attr_value) - }.ToByteArray(); + var serialized = attr_value.ToByteArray(); c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); tf.Status.Check(true); } @@ -260,4 +338,261 @@ public class FuncGraph : Graph, IDisposable { c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); } + + public static FuncGraph func_graph_from_func(string name, Func func, + object[] args, Dictionary kwargs, TensorSpec[] signature = null, + FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, + bool add_control_dependencies = true, string[] arg_names = null, + Tensor op_return_value = null, bool capture_by_value = false, + bool acd_record_initial_resource_uses = false) + { + if(func_graph is null) + { + func_graph = new FuncGraph(name); + } + + // TODO(Rinne): deal with control dependencies. + + func_graph.as_default(); + var current_scope = variable_scope.get_variable_scope(); + var default_use_resource = current_scope.use_resource; + current_scope.use_resource = true; + + if(signature is not null) + { + args = signature; + kwargs = new Dictionary(); + } + var func_args = _get_defun_inputs_from_args(args, arg_names); + var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); + + if(func_kwargs is not null && func_kwargs.Count > 0) + { + throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); + } + + foreach(var arg in nest.flatten(new object[] { func_args, func_kwargs })) + { + if(arg is Tensor tensor && tensor.dtype == dtypes.resource) + { + func_graph._resource_tensor_inputs.Add(tensor); + } + else if (arg is ResourceVariable variable) + { + func_graph._resource_tensor_inputs.Add(variable.Handle); + } + } + + // skip the assignment of `func_graph.structured_input_signature`. + + var flat_func_args = nest.flatten(func_args as object); + var flat_func_kwargs = nest.flatten(func_kwargs as object); + func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) + .Where(x => x is Tensor).Select(x => (Tensor)x)); + + //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); + //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); + + Tensor convert(object x) + { + if (x is null) return null; + Tensor res = null; + if(op_return_value is not null && x is Operation) + { + tf_with(ops.control_dependencies(new object[] { x }), _ => + { + res = array_ops.identity(op_return_value); + }); + } + else if(x is not TensorArray) + { + Debug.Assert(x is Tensor); + res = ops.convert_to_tensor_or_composite(x as Tensor); + } + else + { + throw new NotImplementedException($"The `TensorArray` is not supported here currently."); + } + if (add_control_dependencies) + { + // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. + } + return res; + } + + if (autograph) + { + throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); + } + + var func_outputs = func(func_args); + func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); + func_outputs = func_outputs.Select(x => convert(x)).ToArray(); + // TODO(Rinne): `check_func_mutation`. + + current_scope.use_resource = default_use_resource; + + var graph_variables = func_graph._watched_variables.ToList(); + HashSet arg_variables = new HashSet(); + List inputs = new(); + foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) + { + if(arg is BaseResourceVariable variable) + { + var resource_placeholder = func_graph.pop_capture(variable.Handle); + if(resource_placeholder is null) + { + continue; + } + Debug.Assert(variable is IVariableV1); + arg_variables.Add(variable as IVariableV1); + inputs.Add(resource_placeholder); + } + else if(arg is Tensor tensor) + { + inputs.Add(tensor); + } + } + var variables = graph_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + return null; + } + }).Where(v => v is not null && !arg_variables.Contains(v)); + func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); + func_graph._structured_outputs = func_outputs; + func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) + .Select(x => func_graph.capture(x))); + + func_graph.Variables = variables; + + func_graph.Exit(); + + if (add_control_dependencies) + { + // TODO(Rinne): implement it. + } + return func_graph; + } + + private static object[] _get_defun_inputs_from_args(object[] args, string[] names) + { + return _get_defun_inputs(args, names, args) as object[]; + } + + private static Dictionary _get_defun_inputs_from_kwargs(Dictionary kwargs) + { + // TODO(Rinne): implement it. + Debug.Assert(kwargs is null || kwargs.Count == 0); + return kwargs; + //string[] names; + //object[] args; + //if(kwargs is not null && kwargs.Count > 0) + //{ + // var sorted_kwargs = kwargs.OrderBy(x => x.Key); + // names = sorted_kwargs.Select(x => x.Key).ToArray(); + // args = sorted_kwargs.Select(x => x.Value).ToArray(); + //} + //else + //{ + // names = new string[0]; + // args = new object[0]; + //} + //return _get_defun_inputs(args, names, kwargs) as Dictionary; + } + + private static object _get_defun_inputs(object[] args, string[] names, object structured_args) + { + List function_inputs = new(); + if(names is null) + { + names = new string[args.Length]; + } + + foreach(var (arg_value, name) in zip(args, names)) + { + foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) + { + function_inputs.Add(_get_defun_input(val, name)); + } + } + return nest.pack_sequence_as(structured_args, nest.flatten(function_inputs), true); + } + + private static object _get_defun_input(object arg, string name) + { + var func_graph = ops.get_default_graph() as FuncGraph; + Debug.Assert(func_graph is not null); + if (arg is Tensor tensor) + { + Tensor placeholder; + try + { + placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); + } + catch (ValueError) + { + // TODO(Rinne): Add warning here. + placeholder = tf.placeholder(tensor.dtype, tensor.shape); + } + handle_data_util.copy_handle_data(tensor, placeholder); + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + } + return placeholder; + } + else if (arg is TensorSpec spec) + { + string requested_name; + if (!string.IsNullOrEmpty(spec.name)) + { + requested_name = spec.name; + } + else + { + requested_name = name; + } + Tensor placeholder; + try + { + placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); + } + catch (ValueError) + { + // TODO(Rinne): Add warning here. + placeholder = tf.placeholder(spec.dtype, spec.shape); + } + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(requested_name) + }); + } + return placeholder; + } + else if (arg is BaseResourceVariable variable) + { + var placeholder = func_graph.capture(variable.Handle, name); + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + return arg; + } + // TODO(Rinne): deal with `VariableSpec`. + else + { + return arg; + } + } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs index 91aef2dc..bed8b35c 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs @@ -1,4 +1,6 @@ -namespace Tensorflow +using Tensorflow.Graphs; + +namespace Tensorflow { public partial class Graph { @@ -6,5 +8,10 @@ { } + + internal GraphOverrideGradientContext _override_gradient_function(Dictionary> gradient_function_map) + { + return new GraphOverrideGradientContext(this, gradient_function_map); + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index fc356687..c788aaf0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -118,7 +118,7 @@ namespace Tensorflow /// (Optional.) If True, device functions will be executed /// to compute the device property of the Operation. /// An `Operation` object. - public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) + public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) { var ret = new Operation(c_op, this); _add_op(ret); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index e583868e..f443bcff 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -21,6 +21,7 @@ using System.Collections.Specialized; using System.Linq; using Tensorflow.Framework; using Tensorflow.Functions; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; namespace Tensorflow @@ -88,6 +89,7 @@ namespace Tensorflow private List _unfetchable_ops = new List(); private List _unfeedable_tensors = new List(); private Dictionary _functions = new(); + internal Dictionary> _gradient_function_map = new(); private VersionDef _graph_def_versions = new VersionDef() { Producer = versions.GRAPH_DEF_VERSION, @@ -161,13 +163,30 @@ namespace Tensorflow return _functions.ContainsKey(tf.compat.as_str(name)); } - public void AddFunction(EagerDefinedFunction function) + internal void AddFunction(EagerDefinedFunction function) { _check_not_finalized(); var name = function.Name; + if(function._grad_func_name is not null && function.csharp_grad_func is not null) + { + throw new ValueError($"Gradient defined twice for function {name}"); + } - // TODO(Rinne): deal with c_graph + var c_graph = this.c_graph; + var func = function._c_func.Get(); + Status status = new(); + if (function._grad_func is not null) + { + var gradient = function._grad_func._c_func.Get(); + c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); + status.Check(true); + } + else + { + c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); + status.Check(true); + } _functions[tf.compat.as_str(name)] = function; @@ -332,6 +351,9 @@ namespace Tensorflow private void _create_op_helper(Operation op, bool compute_device = true) { + // high priority + // TODO(Rinne): complete the implementation. + op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); _record_op_seen_by_control_dependencies(op); } @@ -548,6 +570,11 @@ namespace Tensorflow ops.pop_graph(); } + internal EagerDefinedFunction _get_function(string name) + { + return _functions.GetOrDefault(name, null); + } + string debugString = string.Empty; public override string ToString() { diff --git a/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs new file mode 100644 index 00000000..2befbbff --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Tensorflow.Graphs +{ + internal class GraphOverrideGradientContext: ITensorFlowObject + { + Graph _graph; + Dictionary> _new_gradient_function_map; + public GraphOverrideGradientContext(Graph graph, + Dictionary> new_gradient_function_map) + { + _graph = graph; + _new_gradient_function_map = new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __enter__() + { + Debug.Assert(_graph._gradient_function_map.Count == 0); + _graph._gradient_function_map = _new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __exit__() + { + _graph._gradient_function_map = new Dictionary>(); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 28e69886..ca00710c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -20,6 +20,9 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Util; using static Tensorflow.Binding; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using System.Diagnostics; namespace Tensorflow { @@ -47,6 +50,8 @@ namespace Tensorflow private readonly Graph _graph; + internal Func _gradient_function; + public string type => OpType; public Graph graph => _graph; @@ -61,7 +66,7 @@ namespace Tensorflow public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - // OperationDescription _opDesc; + //private OperationDescription _op_desc; public NodeDef node_def => GetNodeDef(); @@ -216,21 +221,19 @@ namespace Tensorflow var x = AttrValue.Parser.ParseFrom(buf.ToArray()); - string oneof_value = x.ValueCase.ToString(); - if (string.IsNullOrEmpty(oneof_value)) - return null; + var oneof_value = x.ValueCase; + if (oneof_value == AttrValue.ValueOneofCase.None) + return new object[0]; - switch (oneof_value.ToLower()) + if(oneof_value == AttrValue.ValueOneofCase.List) { - case "list": - throw new NotImplementedException($"Unsupported field type in {oneof_value}"); - case "type": - return x.Type; - case "s": - return x.S.ToStringUtf8(); - default: - return x.GetType().GetProperty(oneof_value).GetValue(x); + throw new NotImplementedException($"Unsupported field type in {oneof_value}"); } + if(oneof_value == AttrValue.ValueOneofCase.Type) + { + return dtypes.as_tf_dtype(x.Type); + } + return ProtoUtils.GetSingleAttrValue(x, oneof_value); } public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) @@ -309,5 +312,83 @@ namespace Tensorflow } public NDArray numpy() => throw new NotImplementedException(""); + + internal void _add_outputs(TF_DataType[] types, Shape[] shapes) + { + Debug.Assert(types.Length == shapes.Length); + int orig_num_outputs = this.outputs.Length; + //var new_outputs = new List(_outputs); + + var old_outputs = _outputs; + _outputs = new Tensor[orig_num_outputs + types.Length]; + for(int i = 0; i < orig_num_outputs; i++) + { + _outputs[i] = old_outputs[i]; + } + + // Since the `_outputs` is defined as `Array`, when we add new output, we + // have to create a new array, which brings some performance concerns. + // In the future maybe the type of `outputs` should be reconsidered. + for(int i = 0; i < types.Length; i++) + { + var t = new Tensor(this, orig_num_outputs + 1, types[i]); + _outputs[i] = t; + //t = tf.ensure_shape(t, shapes[i]); + t.shape = shapes[i]; + //new_outputs.Add(t); + } + //_outputs = new_outputs.ToArray(); + } + + internal void _set_func_attr(string attr_name, string func_name) + { + var func = new NameAttrList() { Name = func_name }; + _set_attr(attr_name, new AttrValue() { Func = func }); + } + + internal void _set_type_list_attr(string attr_name, DataType[] types) + { + if(types is null || types.Length == 0) + { + return; + } + var type_list = new AttrValue.Types.ListValue(); + type_list.Type.AddRange(types); + _set_attr(attr_name, new AttrValue() { List = type_list }); + } + + internal void _set_attr(string attr_name, AttrValue attr_value) + { + var buffer = new Buffer(attr_value.ToByteArray()); + try + { + _set_attr_with_buf(attr_name, buffer); + } + finally + { + buffer.Release(); + } + } + + internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) + { + //if(_op_desc is null) + //{ + // //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray()); + // //new_node_def.Name += "_temp"; + // //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types); + // //Status status = new(); + // //c_api.TF_SetAttrBool(op._op_desc, "trainable", true); + // ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); + // //status.Check(true); + // // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`. + //} + //else + //{ + // //Status status = new(); + // //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); + // //status.Check(true); + //} + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 46a654e0..43dc8cd4 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -208,9 +208,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); - [DllImport(TensorFlowLibName)] - public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); - [DllImport(TensorFlowLibName)] - public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); + //[DllImport(TensorFlowLibName)] + //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + //[DllImport(TensorFlowLibName)] + //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); } } diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 2d447207..9c2e85d1 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -39,7 +39,7 @@ namespace Tensorflow if (config is null) { - config = function_utils.get_disabled_rewriter_config(); + config = function_utils.get_disabled_rewriter_config().ToString(); } if (executor_type is null) diff --git a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs index ce37ec7d..bb84ac39 100644 --- a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs @@ -79,5 +79,50 @@ namespace Tensorflow.Operations }; } + + public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + "SymbolicGradient", name, input, Tout, f)); + return _result; + } + catch (Exception) + { + + } + + try + { + return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); + } + catch (Exception) + { + + } + } + var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); + var result = op.outputs; + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return result; + } + + public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) + { + object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; + var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return result; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index bf178b60..8f8b2f11 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations /// public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); + return _result[0]; + } + catch (Exception) + { + Console.WriteLine(); + } + try + { + return ensure_shape_eager_fallback(input, shape, name, ctx); + } + catch (Exception) + { + Console.WriteLine(); + } + } + var dict = new Dictionary(); dict["input"] = input; dict["shape"] = shape; var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } return op.output; } + public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) + { + object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; + var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, + attrs, ctx, name); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return _result[0]; + } + /// /// Creates or finds a child frame, and makes data available to the child frame. /// diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs index 5d5fbebb..66daa5c0 100644 --- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -52,5 +52,7 @@ namespace Tensorflow.Operations // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); } + + public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); } } diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 7921f28b..3e39338b 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -24,6 +24,7 @@ using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.Binding; using Tensorflow.Operations; using System.Buffers; +using Tensorflow.Eager; namespace Tensorflow { @@ -41,12 +42,7 @@ namespace Tensorflow name: name); } - public static bool is_resource_variable(IVariableV1 var) - { - return var is BaseResourceVariable; - } - - public static bool is_resource_variable(Trackable var) + public static bool is_resource_variable(object var) { return var is BaseResourceVariable; } @@ -138,10 +134,27 @@ namespace Tensorflow /// internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) { - tensor.HandleData = handle_data; if (!graph_mode) return; + var size = handle_data.ShapeAndType.Count; + + var shapes = new IntPtr[size]; + var types = new DataType[size]; + var ranks = new int[size]; + + for (int i = 0; i < size; i++) + { + var shapeAndType = handle_data.ShapeAndType[i]; + types[i] = shapeAndType.Dtype; + ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; + var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); + } + + //tensor.HandleData = handle_data; + //if (!graph_mode) + // return; + //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); @@ -196,24 +209,6 @@ namespace Tensorflow throw new NotImplementedException(""); } - private static HandleData get_eager_safe_handle_data(Tensor handle) - { - if (handle.Handle == null) - { - var data = new HandleData(); - data.ShapeAndType.Add(new HandleShapeAndType - { - Shape = handle.shape.as_shape_proto(), - Dtype = handle.dtype.as_datatype_enum() - }); - return data; - } - else - { - return HandleData.Parser.ParseFrom(handle.BufferToArray()); - } - } - /// /// Copies an existing variable to a new graph, with no initializer. /// @@ -281,5 +276,31 @@ namespace Tensorflow } } } + + public static HandleData get_eager_safe_handle_data(Tensor handle) + { + if (handle.Handle == null) + { + var data = new HandleData(); + data.ShapeAndType.Add(new HandleShapeAndType + { + Shape = handle.shape.as_shape_proto(), + Dtype = handle.dtype.as_datatype_enum() + }); + return data; + } + else + { + return HandleData.Parser.ParseFrom(handle.BufferToArray()); + } + //if(handle is EagerTensor) + //{ + // return handle.HandleData; + //} + //else + //{ + // return handle_data_util.get_resource_handle_data(handle); + //} + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 79b8d2c5..fff3cde5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -101,6 +101,7 @@ namespace Tensorflow _op = op; _value_index = value_index; _override_dtype = dtype; + _tf_output = null; _id = ops.uid(); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 0bffbfba..6ca65a07 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -136,9 +136,9 @@ namespace Tensorflow protected virtual void SetShapeInternal(Shape value) { if (value == null) - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status); else - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status); } public int[] _shape_tuple() @@ -177,7 +177,9 @@ namespace Tensorflow if (_handle == null) { var output = _as_tf_output(); - int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); + Status status = new(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(true); return ndim; } @@ -199,7 +201,7 @@ namespace Tensorflow public TF_Output _as_tf_output() { if (!_tf_output.HasValue) - _tf_output = new TF_Output(op, value_index); + _tf_output = new TF_Output(op, _value_index); return _tf_output.Value; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 60972775..3d734cd1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -56,7 +56,7 @@ namespace Tensorflow public void Add(Tensor tensor) => items.Add(tensor); - public void AddRange(Tensor[] tensors) + public void AddRange(IEnumerable tensors) => items.AddRange(tensors); public void Insert(int index, Tensor tensor) diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs index 341a12ab..695eadfd 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs @@ -12,11 +12,12 @@ namespace Tensorflow.Training.Saving.SavedModel { public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) { - this.forward_backward = concrete_function.forward_backward; - this.Outputs = concrete_function.Outputs; - this.ReturnType = concrete_function.ReturnType; - this.OutputStructure = concrete_function.OutputStructure; - this.ArgKeywords = concrete_function.ArgKeywords; + throw new NotImplementedException(); + //this.forward_backward = concrete_function.forward_backward; + //this.Outputs = concrete_function.Outputs; + //this.ReturnType = concrete_function.ReturnType; + //this.OutputStructure = concrete_function.OutputStructure; + //this.ArgKeywords = concrete_function.ArgKeywords; } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index 951d7d00..69dd2c10 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -30,6 +30,31 @@ namespace Tensorflow.Training.Saving.SavedModel { var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); + Tensor[] restored_function_body(Tensor[] inputs) + { + if(saved_function.ConcreteFunctions is null || saved_function.ConcreteFunctions.Count == 0) + { + throw new ValueError("Found zero restored functions for caller function."); + } + foreach(var function_name in saved_function.ConcreteFunctions) + { + var function = concrete_functions[function_name]; + if(function.CapturedInputs.Any(x => x is null)) + { + throw new ValueError("Looks like you are trying to run a loaded " + + "non-Keras model that was trained using tf.distribute.experimental.ParameterServerStrategy " + + "with variable partitioning, which is not currently supported. Try using Keras to define your model " + + "if possible."); + } + if(_concrete_function_callable_with(function, inputs, false)) + { + return _call_concrete_function(function, inputs); + } + } + throw new ValueError("Unexpected runtime behavior, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + List concrete_function_objects = new(); foreach(var concrete_function_name in saved_function.ConcreteFunctions) { @@ -40,17 +65,10 @@ namespace Tensorflow.Training.Saving.SavedModel cf._set_function_spec(function_spec); } - foreach(var function_name in saved_function.ConcreteFunctions) - { - var function = concrete_functions[function_name]; - if(_concrete_function_callable_with(function, null, false)) - { - return new RestoredFunction(null, function, "function_from_deserialization"); - } - } - return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); - //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + - // "https://github.com/SciSharp/TensorFlow.NET/issues"); + var restored_function = new RestoredFunction(restored_function_body, nameof(restored_function_body), + function_spec, concrete_function_objects); + + return restored_function; } public static Dictionary load_function_def_library(FunctionDefLibrary library, @@ -102,15 +120,17 @@ namespace Tensorflow.Training.Saving.SavedModel { var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); - if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) + object structured_input_signature = null; + object structured_outputs = null; + if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) { - // TODO(Rinne): implement it. - //var proto = saved_object_graph.ConcreteFunctions[orig_name]; - //throw new NotImplementedException(); + var proto = saved_object_graph.ConcreteFunctions[orig_name]; + structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); + structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); } graph.as_default(); - var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); + var func_graph = function_def_lib.function_def_to_graph(fdef, structured_input_signature, structured_outputs); graph.Exit(); _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); @@ -124,7 +144,7 @@ namespace Tensorflow.Training.Saving.SavedModel { fdef.Attr.Remove("_input_shapes"); } - var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); + var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value)); if(wrapper_function is not null) { throw new NotImplementedException(); @@ -142,8 +162,7 @@ namespace Tensorflow.Training.Saving.SavedModel { var gradient_op_type = gradients_to_register[orig_name]; loaded_gradients[gradient_op_type] = func; - // TODO(Rinne): deal with gradient registry. - //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. + ops.RegisterGradientFunction(gradient_op_type, _gen_gradient_func(func)); } } return functions; @@ -203,6 +222,16 @@ namespace Tensorflow.Training.Saving.SavedModel } } + private static Func _gen_gradient_func(ConcreteFunction func) + { + return (unused_op, result_grads) => + { + result_grads = zip(result_grads, func.func_graph.Inputs) + .Select((item) => item.Item1 is null ? default_gradient.zeros_like(item.Item2) : item.Item1).ToArray(); + return func.CallFlat(result_grads, func.CapturedInputs); + }; + } + private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary renamed_functions, Dictionary loaded_gradients) { foreach(var op in func_graph.get_operations()) @@ -210,14 +239,14 @@ namespace Tensorflow.Training.Saving.SavedModel if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") { var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; - // TODO(Rinne): deal with `op._gradient_function`. + op.op._gradient_function = function._get_gradient_function(); } string gradient_op_type = null; try { gradient_op_type = op.op.get_attr("_gradient_op_type") as string; } - catch(Exception e) + catch(InvalidArgumentError) { continue; } @@ -389,7 +418,7 @@ namespace Tensorflow.Training.Saving.SavedModel concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; - var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + //var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); // TODO(Rinne): set the functiona spec. concrete_function.AddTograph(); return concrete_function; @@ -413,19 +442,31 @@ namespace Tensorflow.Training.Saving.SavedModel return function.CallFlat(inputs, function.CapturedInputs); } - private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) + private static bool _concrete_function_callable_with(ConcreteFunction function, Tensor[] inputs, bool allow_conversion) { // TODO(Rinne): revise it. - return true; + return function.CapturedInputs.Length + inputs.Length == function.Inputs.Length; + //var expected_inputs = function.func_graph.Inputs; + //foreach(var (arg, expected) in zip(inputs, expected_inputs)) + //{ + // if(arg.Id != expected.Id) + // { + // return false; + // } + //} + //return true; } } public class RestoredFunction : Function { - public RestoredFunction(Func function, ConcreteFunction concrete_function, - string name, bool auto_graph = true): base(function, name, auto_graph) + IEnumerable _concrete_functions; + FunctionSpec _function_spec; + public RestoredFunction(Func function, string name, FunctionSpec function_spec, + IEnumerable concrete_functions): base(function, name, auto_graph: false) { - _concrete_variable_creation_fn = concrete_function; + _concrete_functions = concrete_functions; + _function_spec = function_spec; } protected override bool _run_functions_eagerly() diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 4a0d3b00..d3ffebc9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -102,6 +102,6 @@ public class SignatureMap: Trackable return new Dictionary(); } - return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); + return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } } diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index a8033f59..6b607e85 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -132,8 +132,8 @@ namespace Tensorflow.Training { get { - var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); - var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); + var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList(); + var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList(); List non_trainable_variables = new(); foreach(var obj in Values) { @@ -576,7 +576,7 @@ namespace Tensorflow.Training if(save_type == SaveType.SAVEDMODEL) { - children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); + children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); } return children; diff --git a/src/TensorFlowNET.Core/Util/ProtoUtils.cs b/src/TensorFlowNET.Core/Util/ProtoUtils.cs new file mode 100644 index 00000000..e7de8e30 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/ProtoUtils.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + internal static class ProtoUtils + { + public static object GetSingleAttrValue(AttrValue value, AttrValue.ValueOneofCase valueCase) + { + return valueCase switch + { + AttrValue.ValueOneofCase.S => value.S, + AttrValue.ValueOneofCase.I => value.I, + AttrValue.ValueOneofCase.F => value.F, + AttrValue.ValueOneofCase.B => value.B, + AttrValue.ValueOneofCase.Type => value.Type, + AttrValue.ValueOneofCase.Shape => value.Shape, + AttrValue.ValueOneofCase.Tensor => value.Tensor, + AttrValue.ValueOneofCase.Func => value.Func, + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/function_utils.cs b/src/TensorFlowNET.Core/Util/function_utils.cs index 2944e88e..d4ba4423 100644 --- a/src/TensorFlowNET.Core/Util/function_utils.cs +++ b/src/TensorFlowNET.Core/Util/function_utils.cs @@ -7,15 +7,15 @@ namespace Tensorflow.Util { internal static class function_utils { - private static string _rewriter_config_optimizer_disabled; - public static string get_disabled_rewriter_config() + private static ByteString _rewriter_config_optimizer_disabled; + public static ByteString get_disabled_rewriter_config() { if(_rewriter_config_optimizer_disabled is null) { var config = new ConfigProto(); var rewriter_config = config.GraphOptions.RewriteOptions; rewriter_config.DisableMetaOptimizer = true; - _rewriter_config_optimizer_disabled = config.ToString(); + _rewriter_config_optimizer_disabled = config.ToByteString(); } return _rewriter_config_optimizer_disabled; } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index c4537896..eb94f4d0 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -137,10 +137,12 @@ namespace Tensorflow.Util switch (instance) { case Hashtable hash: - var result = new Hashtable(); - foreach ((object key, object value) in zip(_sorted(hash), args)) - result[key] = value; - return result; + { + var result = new Hashtable(); + foreach ((object key, object value) in zip(_sorted(hash), args)) + result[key] = value; + return result; + } } } //else if( _is_namedtuple(instance) || _is_attrs(instance)) @@ -221,6 +223,16 @@ namespace Tensorflow.Util return list; } + public static List flatten(IEnumerable structure) + { + var list = new List(); + foreach(var item in structure) + { + _flatten_recursive(item, list); + } + return list; + } + public static object[] flatten2(ICanBeFlattened structure) => structure.Flatten(); @@ -527,6 +539,14 @@ namespace Tensorflow.Util return pack_sequence_as(structure, mapped_flat_structure) as T2; } + public static IEnumerable map_structure(Func func, IEnumerable structure) where T2 : class + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); + + return pack_sequence_as(structure, mapped_flat_structure) as IEnumerable; + } + /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// diff --git a/src/TensorFlowNET.Core/Util/variable_utils.cs b/src/TensorFlowNET.Core/Util/variable_utils.cs new file mode 100644 index 00000000..13237f9d --- /dev/null +++ b/src/TensorFlowNET.Core/Util/variable_utils.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; + +namespace Tensorflow.Util +{ + internal static class variable_utils + { + public static Tensor[] convert_variables_to_tensors(object[] values) + { + return values.Select(x => + { + if (resource_variable_ops.is_resource_variable(x)) + { + return ops.convert_to_tensor(x); + } + else if (x is CompositeTensor) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else if(x is Tensor tensor) + { + return tensor; + } + else + { + throw new TypeError("Currently the output of function to be traced must be `Tensor`."); + } + }).ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index bce64198..7aadb206 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -248,7 +248,7 @@ namespace Tensorflow foreach (var attr in node_def.Attr) { var bytes = attr.Value.ToByteArray(); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status); status.Check(true); } @@ -575,10 +575,12 @@ namespace Tensorflow public static HandleData get_resource_handle_data(Tensor graph_op) { + throw new NotImplementedException(); // This implementation hasn't been checked for some reasons. // If it throws an exception in the future, please check it. - var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); + + //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); + //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); } public static void dismantle_graph(Graph graph) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index d8171e2a..5cf34250 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -35,6 +35,10 @@ namespace Tensorflow.Keras.Engine { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); + //foreach (var variable in TrainableVariables) + //{ + // tape.watch(variable.Handle); + //} var y_pred = Apply(x, training: true); var loss = compiled_loss.Call(y, y_pred); diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index c7b9157b..1ac4a277 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers inputs.Insert(index, value); } - var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); - var op = graph._create_op_from_tf_operation(c_op); + var (c_op, op_desc) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); + var op = graph._create_op_from_tf_operation(c_op, desc: op_desc); op._control_flow_post_processing(); // Record the gradient because custom-made ops don't go through the diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index d7df6eb2..9d611efe 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -51,9 +51,9 @@ namespace Tensorflow.Keras.Saving.SavedModel _all_functions = new HashSet(objects_and_functions.Item2); } - public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + public IDictionary Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); - public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + public IDictionary CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); /// /// Returns functions to attach to the root object during serialization. @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving.SavedModel { get { - var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + var objects = CheckpointableObjects.Where( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); objects[Constants.KERAS_ATTR] = _keras_trackable; return objects; } diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb new file mode 100644 index 00000000..361ca3a8 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb new file mode 100644 index 00000000..b98e1733 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb @@ -0,0 +1,6 @@ + +root"_tf_keras_sequential*{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}}]}, "shared_object_id": 3, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 784]}, "ndim": 2, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2}]}}, "training_config": {"loss": "sparse_categorical_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2 +root.layer_with_weights-0"_tf_keras_layer*{"name": "transformer", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 + root.layer-1"_tf_keras_layer*{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 +9root.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 6}2 +:root.keras_api.metrics.1"_tf_keras_metric*{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb new file mode 100644 index 00000000..f22755e0 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000..399265af Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 differ diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index new file mode 100644 index 00000000..e0b0e800 Binary files /dev/null and b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index differ diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 17d864d2..cb230605 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -6,7 +6,6 @@ using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; using static Tensorflow.Binding; -using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest.SaveModel; @@ -62,11 +61,26 @@ public class SequentialModelLoad [TestMethod] public void Temp() { - var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); + var model = tf.keras.models.load_model(@"Assets/python_func_model"); model.summary(); - var x = tf.ones((2, 10)); + var x = tf.random.uniform((8, 784), -1, 1); var y = model.Apply(x); Console.WriteLine(y); + + //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + //var data_loader = new MnistModelLoader(); + //var num_epochs = 1; + //var batch_size = 8; + + //var dataset = data_loader.LoadAsync(new ModelLoadSetting + //{ + // TrainDir = "mnist", + // OneHot = false, + // ValidationSize = 58000, + //}).Result; + + //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); } } diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index 7d519bf6..a5c381fe 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -49,6 +49,22 @@ PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + diff --git a/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs index 0872394b..9230bc73 100644 --- a/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs @@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest ASSERT_EQ(TF_OK, s_.Code, s_.Message); ASSERT_NE(func_, IntPtr.Zero); ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); - c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); + c_api.TF_GraphCopyFunction(host_graph_, func_, new SafeFuncGraphHandle(IntPtr.Zero), s_); ASSERT_EQ(TF_OK, s_.Code, s_.Message); }