From fd1eb40f25968b10186f9a4219b27f32487d1c04 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 10 Apr 2023 14:42:31 +0800 Subject: [PATCH] Partially support the backward of loaded function model. --- .../Extensions/DictionaryExtension.cs | 31 ++ src/TensorFlowNET.Core/APIs/tf.gradients.cs | 2 +- src/TensorFlowNET.Core/APIs/tf.tensor.cs | 7 + .../Attributes/c_api.ops.cs | 2 +- src/TensorFlowNET.Core/Binding.Util.cs | 1 + src/TensorFlowNET.Core/Buffers/Buffer.cs | 6 + .../Checkpoint/CheckPointUtils.cs | 2 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 1 + .../Contexts/Context.Config.cs | 89 ++++- src/TensorFlowNET.Core/Contexts/Context.cs | 44 ++- .../Contexts/FunctionCallOptions.cs | 5 +- .../Eager/EagerRunner.TFE_Execute.cs | 1 + src/TensorFlowNET.Core/Eager/backprop_util.cs | 53 +++ src/TensorFlowNET.Core/Eager/c_api.eager.cs | 4 +- .../Framework/Models/ScopedTFFunction.cs | 6 - .../Framework/ScopedTFFunction.cs | 22 ++ .../Framework/function_def_lib.cs | 15 +- .../Functions/ConcreteFunction.cs | 56 +-- .../Functions/EagerDefinedFunction.cs | 117 ++++-- .../FirstOrderTapeGradientFunctions.cs | 4 +- src/TensorFlowNET.Core/Functions/Function.cs | 33 +- .../Functions/TapeGradientFunctions.cs | 60 +-- .../Functions/TracingCompiler.cs | 84 +++++ .../Functions/c_api.function.cs | 6 +- .../Functions/composite_tensor_utils.cs | 50 +++ .../Functions/function_saved_model_utils.cs | 15 +- .../Functions/monomorphic_function.cs | 268 ++++++++++++- .../Gradients/default_gradient.cs | 52 +++ .../Gradients/gradients_util.cs | 95 ++++- .../ops.gradient_function_mapping.cs | 19 +- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 355 +++++++++++++++++- .../Graphs/Graph.Gradient.cs.cs | 9 +- .../Graphs/Graph.Operation.cs | 2 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 31 +- .../Graphs/GraphOverrideGradientContext.cs | 37 ++ .../Operations/Operation.cs | 107 +++++- .../Operations/c_api.ops.cs | 8 +- .../Operations/functional_ops.cs | 2 +- .../Operations/gen_functional_ops.cs | 45 +++ src/TensorFlowNET.Core/Operations/gen_ops.cs | 38 ++ .../Operations/handle_data_util.cs | 2 + .../Operations/resource_variable_ops.cs | 71 ++-- .../Tensors/Tensor.Creation.cs | 1 + src/TensorFlowNET.Core/Tensors/Tensor.cs | 10 +- src/TensorFlowNET.Core/Tensors/Tensors.cs | 2 +- .../Saving/SavedModel/WrapperFunction.cs | 11 +- .../SavedModel/function_deserialization.cs | 95 +++-- .../SavedModel/signature_serialization.cs | 2 +- .../Training/data_structures.cs | 6 +- src/TensorFlowNET.Core/Util/ProtoUtils.cs | 24 ++ src/TensorFlowNET.Core/Util/function_utils.cs | 6 +- src/TensorFlowNET.Core/Util/nest.py.cs | 28 +- src/TensorFlowNET.Core/Util/variable_utils.cs | 33 ++ src/TensorFlowNET.Core/ops.cs | 8 +- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 4 + .../Layers/TensorFlowOpLayer.cs | 4 +- .../SavedModel/serialized_attributes.cs | 6 +- .../Assets/python_func_model/fingerprint.pb | Bin 0 -> 54 bytes .../python_func_model/keras_metadata.pb | 6 + .../Assets/python_func_model/saved_model.pb | Bin 0 -> 47187 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 34194 bytes .../variables/variables.index | Bin 0 -> 575 bytes .../SaveModel/SequentialModelLoad.cs | 20 +- .../Tensorflow.Keras.UnitTest.csproj | 16 + .../Functions/FunctionTest.cs | 2 +- 65 files changed, 1886 insertions(+), 255 deletions(-) create mode 100644 Tensorflow.Common/Extensions/DictionaryExtension.cs create mode 100644 src/TensorFlowNET.Core/Eager/backprop_util.cs delete mode 100644 src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs create mode 100644 src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs create mode 100644 src/TensorFlowNET.Core/Functions/TracingCompiler.cs create mode 100644 src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs create mode 100644 src/TensorFlowNET.Core/Gradients/default_gradient.cs create mode 100644 src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs create mode 100644 src/TensorFlowNET.Core/Util/ProtoUtils.cs create mode 100644 src/TensorFlowNET.Core/Util/variable_utils.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb create mode 100644 test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb create mode 100644 test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb create mode 100644 test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 create mode 100644 test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index diff --git a/Tensorflow.Common/Extensions/DictionaryExtension.cs b/Tensorflow.Common/Extensions/DictionaryExtension.cs new file mode 100644 index 00000000..7502a3a7 --- /dev/null +++ b/Tensorflow.Common/Extensions/DictionaryExtension.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Tensorflow.Common.Extensions +{ + public static class DictionaryExtension + { + public static void Deconstruct(this KeyValuePair pair, out T1 first, out T2 second) + { + first = pair.Key; + second = pair.Value; + } + public static void Update(this Dictionary dic, IDictionary other) + { + foreach(var (key, value) in other) + { + dic[key] = value; + } + } + public static T2 GetOrDefault(this Dictionary dic, T1 key, T2 defaultValue) + { + if (dic.ContainsKey(key)) + { + return dic[key]; + } + return defaultValue; + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs index d722cb14..492b1034 100644 --- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -21,7 +21,7 @@ namespace Tensorflow { public partial class tensorflow { - GradientTape _tapeSet; + internal GradientTape _tapeSet; /// /// Record operations for automatic differentiation. diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index 91293b3a..35efde06 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Operations; + namespace Tensorflow { public partial class tensorflow @@ -79,5 +81,10 @@ namespace Tensorflow num_split: num_split, axis: axis, name: name); + + public Tensor ensure_shape(Tensor x, Shape shape, string name = null) + { + return gen_ops.ensure_shape(x, shape, name); + } } } diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs index 7d9ff65f..2a22413b 100644 --- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -61,7 +61,7 @@ namespace Tensorflow public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status); /// /// Set `num_dims` to -1 to represent "unknown rank". diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 5d9d799d..8df39334 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -22,6 +22,7 @@ using System.ComponentModel; using System.Diagnostics; using System.IO; using System.Linq; +using Tensorflow.Operations; namespace Tensorflow { diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 9ec9e22f..330e30ca 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -107,6 +107,12 @@ namespace Tensorflow } } + public void Release() + { + _handle.Dispose(); + _handle = null; + } + public override string ToString() => $"0x{_handle.DangerousGetHandle():x16}"; diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 9793798d..490c284b 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -161,7 +161,7 @@ public static class CheckPointUtils internal static IEnumerable _objects_with_attributes(IEnumerable full_list) { - return full_list.TakeWhile(x => + return full_list.Where(x => { var saveables = x.gather_saveables_for_checkpoint(); return saveables is not null && saveables.Count > 0; diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index 4aa2a808..7a5da7e3 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -109,6 +109,7 @@ namespace Tensorflow.Checkpoint TrackableObjectGraph.Types.TrackableObject trackable_object = new(); trackable_object.SlotVariables.AddRange(td.slot_variable_proto); trackable_object.Children.AddRange(td.children_proto); + object_graph_proto.Nodes.Add(trackable_object); } return object_graph_proto; } diff --git a/src/TensorFlowNET.Core/Contexts/Context.Config.cs b/src/TensorFlowNET.Core/Contexts/Context.Config.cs index b363b516..0c7bded6 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.Config.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.Config.cs @@ -14,9 +14,11 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Diagnostics; using System.Linq; +using Tensorflow.Common.Extensions; namespace Tensorflow.Contexts { @@ -25,12 +27,93 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { - public ConfigProto Config { get; set; } = new ConfigProto + protected Device.PhysicalDevice[] _physical_devices; + protected Dictionary _physical_device_to_index; + ConfigProto _config; + public ConfigProto Config { - GpuOptions = new GPUOptions + get { + _initialize_physical_devices(); + + var config = new ConfigProto(); + if(_config is not null) + { + config.MergeFrom(_config); + } + config.LogDevicePlacement = _log_device_placement; + + config.DeviceCount["CPU"] = 0; + config.DeviceCount["GPU"] = 0; + foreach(var dev in _physical_devices) + { + if (config.DeviceCount.ContainsKey(dev.DeviceType)) + { + config.DeviceCount[dev.DeviceType] += 1; + } + else + { + config.DeviceCount[dev.DeviceType] = 1; + } + } + + var gpu_options = _compute_gpu_options(); + config.GpuOptions = GPUOptions.Parser.ParseFrom(gpu_options.ToByteArray()); + + return config; + } + set + { + _config = value; + } + } + + protected void _initialize_physical_devices(bool reinitialize = false) + { + if(!reinitialize && _physical_devices is not null) + { + return; + } + var devs = list_physical_devices(); + _physical_devices = devs.Select(d => new Device.PhysicalDevice() + { + DeviceName = d.DeviceName, + DeviceType = d.DeviceType + }).ToArray(); + _physical_device_to_index = _physical_devices.Select((p, i) => new KeyValuePair(p, i)) + .ToDictionary(x => x.Key, x => x.Value); + + _import_config(); + } + + protected void _import_config() + { + if(_config is null) + { + return; + } + if(!_config.DeviceCount.TryGetValue("CPU", out var num_cpus)) + { + num_cpus = 1; + } + if(num_cpus != 1) + { + // TODO(Rinne): implement it. } - }; + + var gpus = _physical_devices.Where(d => d.DeviceType == "GPU"); + if(gpus.Count() == 0) + { + return; + } + + if(!_config.DeviceCount.TryGetValue("GPU", out var gpu_count)) + { + gpu_count = 0; + } + + // TODO(Rinne): implement it. + } ConfigProto MergeConfig() { diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index deb67920..7fec1e5a 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -38,7 +38,26 @@ namespace Tensorflow.Contexts public string ScopeName { get; set; } = ""; bool initialized = false; ContextSwitchStack context_switches; - public FunctionCallOptions FunctionCallOptions { get; } + protected FunctionCallOptions _function_call_options; + public FunctionCallOptions FunctionCallOptions + { + get + { + if(_function_call_options is null) + { + var config = Config; + _function_call_options = new FunctionCallOptions() + { + Config = config + }; + } + return _function_call_options; + } + set + { + _function_call_options = value; + } + } SafeContextHandle _handle; @@ -62,7 +81,6 @@ namespace Tensorflow.Contexts if (initialized) return; - Config = MergeConfig(); FunctionCallOptions.Config = Config; var config_str = Config.ToByteArray(); var opts = new ContextOptions(); @@ -167,11 +185,29 @@ namespace Tensorflow.Contexts return c_api.TFE_ContextHasFunction(_handle, name); } + public void add_function(SafeFuncGraphHandle fn) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextAddFunction(_handle, fn, status); + status.Check(true); + } + + public void remove_function(string name) + { + ensure_initialized(); + Status status = new(); + c_api.TFE_ContextRemoveFunction(_handle, name, status); + status.Check(true); + } + public void add_function_def(FunctionDef fdef) { ensure_initialized(); - var fdef_string = fdef.ToString(); - c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); + var fdef_string = fdef.ToByteArray(); + Status status = new Status(); + c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, (ulong)fdef_string.Length, status); + status.Check(true); } public void restore_mode() diff --git a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs index 6b6028f0..2fcf9dce 100644 --- a/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs +++ b/src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs @@ -9,10 +9,11 @@ namespace Tensorflow.Contexts public class FunctionCallOptions { public ConfigProto Config { get; set; } + public string ExecutorType { get; set; } - public string config_proto_serialized() + public ByteString config_proto_serialized() { - return Config.ToByteString().ToStringUtf8(); + return Config.ToByteString(); } } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index aa205d45..3806b3ad 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow.Eager diff --git a/src/TensorFlowNET.Core/Eager/backprop_util.cs b/src/TensorFlowNET.Core/Eager/backprop_util.cs new file mode 100644 index 00000000..0d726e1d --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/backprop_util.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; + +namespace Tensorflow.Eager +{ + internal static class backprop_util + { + // TODO: add quantized_dtypes (after being supported). + private static HashSet _trainable_dtypes = new HashSet(new TF_DataType[] + { + dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128, + dtypes.resource, dtypes.variant, TF_DataType.TF_BFLOAT16 + }); + public static bool IsTrainable(Tensor tensor) + { + var dtype = _DTypeFromTensor(tensor); + return _trainable_dtypes.Contains(dtype); + } + public static bool IsTrainable(TF_DataType dtype) + { + return _trainable_dtypes.Contains(dtype); + } + + private static TF_DataType _DTypeFromTensor(Tensor tensor) + { + var dtype = tensor.dtype; + if(dtype.as_base_dtype() == TF_DataType.TF_VARIANT) + { + CppShapeInferenceResult.Types.HandleData handle_data; + if (tensor is EagerTensor) + { + handle_data = tensor.HandleData; + } + else + { + handle_data = handle_data_util.get_resource_handle_data(tensor); + } + if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null && + handle_data.ShapeAndType.Count > 0) + { + var first_type = handle_data.ShapeAndType[0].Dtype; + if(first_type != DataType.DtInvalid && handle_data.ShapeAndType.All(x => x.Dtype == first_type)) + { + return first_type.as_tf_dtype(); + } + } + } + return dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index e8746c1b..665e537f 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -31,7 +31,7 @@ namespace Tensorflow public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); + public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, byte[] serialized_function_def, ulong size, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); @@ -280,7 +280,7 @@ namespace Tensorflow public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); [DllImport(TensorFlowLibName)] - public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); + public static extern void TFE_OpSetAttrValueProto(IntPtr op, string attr_name, IntPtr proto, ulong proto_len, SafeStatusHandle status); /// /// diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs deleted file mode 100644 index bce889b6..00000000 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFFunction.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Tensorflow.Framework.Models -{ - class ScopedTFFunction - { - } -} diff --git a/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs new file mode 100644 index 00000000..11e920f8 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/ScopedTFFunction.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Framework +{ + internal class ScopedTFFunction + { + SafeFuncGraphHandle _handle; + string _name; + public ScopedTFFunction(SafeFuncGraphHandle func, string name) + { + _handle = func; + _name = name; + } + + public SafeFuncGraphHandle Get() + { + return _handle; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs index b81cb71b..67f8d324 100644 --- a/src/TensorFlowNET.Core/Framework/function_def_lib.cs +++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Security.Cryptography; using System.Text; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; using static Tensorflow.CppShapeInferenceResult.Types; @@ -64,7 +65,7 @@ namespace Tensorflow.Framework { output_names[ops.tensor_id(func_graph.get_tensor_by_name(tensor_name))] = ret_arg_def.Name; } - // TODO(Rinne): func_graph._output_names = output_names + func_graph._output_names = output_names; func_graph.Exit(); return func_graph; @@ -154,9 +155,17 @@ namespace Tensorflow.Framework foreach(var node_def in fdef.NodeDef) { var graph = default_graph; - // TODO(Rinne): The `Graph` lacks `_functions`, needed to be implemented in the future. - while(graph.OuterGraph is not null) + while (true) { + if(graph is null) + { + break; + } + var f = graph.Functions.GetOrDefault(node_def.Op, null); + if(f is not null && graph.OuterGraph is null) + { + break; + } graph = graph.OuterGraph; } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 402d876e..8524f724 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Linq; using Tensorflow.Eager; using Tensorflow.Framework.Models; +using Tensorflow.Gradients; using Tensorflow.Graphs; using Tensorflow.Train; using Tensorflow.Util; @@ -19,7 +20,7 @@ namespace Tensorflow.Functions protected IEnumerable _captured_inputs; internal FuncGraph func_graph; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; - protected Dictionary _attrs; + protected Dictionary _attrs; protected FunctionSpec _function_spec; protected FunctionSpec _pre_initialized_function_spec = null; protected EagerDefinedFunction _inference_function; @@ -29,22 +30,25 @@ namespace Tensorflow.Functions public string Name => _delayed_rewrite_functions.Forward().Name; - public Tensor[] Outputs; + public Tensor[] Outputs => func_graph.Outputs; public Type ReturnType; public TensorSpec[] OutputStructure; public IEnumerable ArgKeywords { get; set; } public long NumPositionArgs { get; set; } public FunctionDef FunctionDef => _delayed_rewrite_functions.Forward().Definition; + public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; + public IEnumerable Variables => func_graph.Variables; + public IEnumerable TrainableVariables => func_graph.TrainableVariables; public ConcreteFunction(string name) { func_graph = new FuncGraph(name); _captured_inputs = func_graph.external_captures; - _attrs= new Dictionary(); + _attrs= new Dictionary(); _set_infer_function(); } - public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) + public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) { func_graph = graph; _captured_inputs = func_graph.external_captures; @@ -70,7 +74,7 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; - _attrs = new Dictionary(); + _attrs = new Dictionary(); _set_infer_function(); } @@ -93,7 +97,7 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; - _attrs = new Dictionary(); + _attrs = new Dictionary(); _set_infer_function(); } @@ -160,27 +164,20 @@ namespace Tensorflow.Functions } if (!executing_eagerly) { - + // TODO(Rinne): add the check } - tensor_inputs.AddRange(captured_inputs); + tensor_inputs.AddRange(captured_inputs); args = tensor_inputs.ToArray(); - var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0; + var possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args); // No tape is watching; skip to running the function. - if (possible_gradient_type == 0 && executing_eagerly) + if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE && executing_eagerly) { return _build_call_outputs(_inference_function.Call(args)); - //var attrs = new object[] - //{ - // "executor_type", "", - // "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() - //}; - //return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } - if (forward_backward == null) - forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); + forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); var (forward_function, args_with_tangents) = forward_backward.Forward(); Tensors flat_outputs = null; if (executing_eagerly) @@ -189,8 +186,12 @@ namespace Tensorflow.Functions } else { - // TODO(Rinne): add `default_graph._override_gradient_function`. - flat_outputs = forward_function.Call(args_with_tangents); + tf_with(default_graph._override_gradient_function(new Dictionary>(){ + { "PartitionedCall", _get_gradient_function() }, { "StatefulPartitionedCall", _get_gradient_function() } + }), _ => + { + flat_outputs = forward_function.Call(args_with_tangents); + }); } forward_backward.Record(flat_outputs); return _build_call_outputs(flat_outputs); @@ -215,7 +216,8 @@ namespace Tensorflow.Functions TangentInfo input_tangents; if (executing_eagerly) { - throw new NotImplementedException(); + // TODO(Rinne): check if it needs to be implemented. + input_tangents = new TangentInfo(); } else { @@ -239,7 +241,12 @@ namespace Tensorflow.Functions } // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. - return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient()); + } + + internal void set_variables(IEnumerable variables) + { + func_graph.Variables = variables; } internal void _set_infer_function() @@ -274,6 +281,11 @@ namespace Tensorflow.Functions }; } + internal Func _get_gradient_function() + { + return _delayed_rewrite_functions._rewrite_forward_and_call_backward; + } + private Tensors _build_call_outputs(Tensors result) { // TODO(Rinne): dwal with `func_graph.structured_outputs` diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index 61d3121c..c2f8e016 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -9,18 +9,27 @@ using Tensorflow.Eager; using Tensorflow.Graphs; using Tensorflow.Operations; using Tensorflow.Util; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; +using Tensorflow.Framework; +using System.Buffers; +using Tensorflow.Gradients; namespace Tensorflow.Functions { - public class EagerDefinedFunction + public class EagerDefinedFunction: IDisposable { public int _num_outputs; - FuncGraph _func_graph; + FuncGraph _graph; FunctionDef _definition; OpDef _signature; string _name; - Tensor[] _func_graph_outputs; + internal ScopedTFFunction _c_func; + internal Tensor[] _func_graph_outputs; + internal string _grad_func_name; + internal Func csharp_grad_func; + internal EagerDefinedFunction _grad_func; + internal bool _registered_on_context = false; public string Name => _name; public DataType[] OutputTypes { get; protected set; } public Shape[] OutputShapes { get; protected set; } @@ -47,48 +56,93 @@ namespace Tensorflow.Functions return _signature; } } - public EagerDefinedFunction(string name, FuncGraph graph, + public unsafe EagerDefinedFunction(string name, FuncGraph graph, Tensors inputs, Tensors outputs, - Dictionary attrs) + Dictionary attrs) { var input_ops = inputs.Select(x => x.op).ToArray(); var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) .Select(x => x as Operation).ToArray(); - var output_names = new string[0]; - _func_graph = new FuncGraph(graph, name, attrs); - _func_graph_outputs = new List(outputs).ToArray(); - _func_graph.ToGraph(operations, inputs, outputs, output_names); + var graph_output_names = graph._output_names; + string[] output_names; + if(graph_output_names is not null && outputs.All(t => graph_output_names.ContainsKey(ops.tensor_id(t)))) + { + output_names = outputs.Select(t => graph_output_names[ops.tensor_id(t)]).ToArray(); + if(output_names.Distinct().Count() != output_names.Length) + { + output_names = new string[0]; + } + } + else + { + output_names = new string[0]; + } + + Status status = new Status(); + var fn = c_api.TF_GraphToFunction(graph.c_graph, + name, + false, + operations.Length, + operations.Length == 0 ? new IntPtr[0] : operations.Select(x => (IntPtr)x).ToArray(), + inputs.Length, + inputs.Select(t => t._as_tf_output()).ToArray(), + outputs.Length, + outputs.Select(t => t._as_tf_output()).ToArray(), + output_names.Length != outputs.Length ? null : output_names, + IntPtr.Zero, // warning: the control output hasbben totally ignored. + null, + status); + status.Check(true); + + _c_func = new ScopedTFFunction(fn, name); + + foreach(var (attr_name, attr_value) in attrs) + { + var serialized = attr_value.ToByteArray(); + c_api.TF_FunctionSetAttrValueProto(fn, attr_name, serialized, serialized.Length, status); + status.Check(true); + } var signature = _get_definition().Signature; _name = signature.Name; - // TODO(Rinne): deal with `fn` + tf_with(ops.init_scope(), s => + { + tf.Context.add_function(fn); + _registered_on_context = true; + }); _num_outputs = signature.OutputArg.Count; OutputTypes = signature.OutputArg.Select(x => x.Type).ToArray(); OutputShapes = outputs.Select(x => x.shape).ToArray(); + _func_graph_outputs = new List(outputs).ToArray(); + csharp_grad_func = null; + _graph = graph; } - public Tensors Call(Tensors args) + public unsafe Tensors Call(Tensors args) { // TODO(Rinne): Add arg `CancellationManager`. // TODO(Rinne): Check the arg length. var function_call_options = tf.Context.FunctionCallOptions; string config; - if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) + if (function_call_options.config_proto_serialized().Length == 0) { - config = function_utils.get_disabled_rewriter_config(); + config = function_utils.get_disabled_rewriter_config().ToString(); } else { - config = function_call_options.config_proto_serialized(); + config = function_call_options.config_proto_serialized().ToString(); } - // TODO(Rinne): executor_type + + config = ""; // TODO(Rinne): revise it. + + string executor_type = function_call_options.ExecutorType ?? ""; var executing_eagerly = tf.Context.executing_eagerly(); var attrs = new object[] { - "executor_type", "", - "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + "executor_type", executor_type, + "config_proto", config }; Tensor[] outputs; @@ -103,9 +157,19 @@ namespace Tensorflow.Functions } else { - tf.GradientTape().stop_recording(); - outputs = functional_ops.partitioned_call(args, this, OutputTypes, - executing_eagerly, config, ""); + if(tf.GetTapeSet().Count == 0) + { + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + } + else + { + var tape = tf.GetTapeSet().Peek(); + tape.StopRecord(); + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + tape.StartRecord(); + } } foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) { @@ -141,7 +205,7 @@ namespace Tensorflow.Functions { g.AddFunction(this); } - foreach(var f in _func_graph.Functions.Values) + foreach(var f in _graph.Functions.Values) { if (!g.IsFunction(f.Name)) { @@ -155,12 +219,15 @@ namespace Tensorflow.Functions { var buffer = c_api_util.tf_buffer(); Status status = new(); - c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status); + c_api.TF_FunctionToFunctionDef(_c_func.Get(), buffer, status); status.Check(true); var proto_data = c_api.TF_GetBuffer(buffer); - FunctionDef function_def = new(); - function_def.MergeFrom(proto_data.AsSpan()); - return function_def; + return FunctionDef.Parser.ParseFrom(proto_data.AsSpan()); + } + + public void Dispose() + { + tf.Context.remove_function(Name); } } } diff --git a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs index 3c099927..c0e69dba 100644 --- a/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs @@ -17,9 +17,9 @@ namespace Tensorflow.Functions public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) { var outputs = _func_graph.Outputs; - (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) + (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs) = BuildFunctionsForOutputs(outputs, inference_args); - return _forward; + return _forward_function; } } } diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index cfea3954..a53df14c 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -10,23 +10,26 @@ namespace Tensorflow private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - protected Func _function; + protected Func _csharp_function; protected ConcreteFunction _concrete_variable_creation_fn; - protected bool _auto_graph; + protected bool _autograph; + protected TracingCompiler _variable_creation_fn; + protected bool _has_initialized; public string Name { get; set; } - public Function(Func function, + public Function(Func csharp_function, string name, bool auto_graph = true) { - _function = function; + _csharp_function = csharp_function; Name = name; - _auto_graph = auto_graph; + _autograph = auto_graph; + _has_initialized = false; } public virtual Tensors Apply(Tensors inputs) { if (_run_functions_eagerly()) { - return _function(inputs); + return _csharp_function(inputs); } var result = _call(inputs); @@ -35,20 +38,32 @@ namespace Tensorflow protected virtual Tensors _call(Tensors inputs) { - _initialize(); + if (!_has_initialized) + { + _initialize(inputs); + } return _concrete_variable_creation_fn.CallFlat(inputs, _concrete_variable_creation_fn.CapturedInputs); } + protected TracingCompiler _compiler(Func fn) + { + var name = nameof(fn); + return new TracingCompiler(fn, name, autograph: _autograph); + } + protected virtual bool _run_functions_eagerly() { return false; } - private void _initialize() + private void _initialize(Tensor[] args) { - + _variable_creation_fn = _compiler(_csharp_function); + _variable_creation_fn._name = this.Name; + _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); + _has_initialized = true; } } } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 23889d44..638aeaf1 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -3,8 +3,10 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Eager; +using Tensorflow.Gradients; using Tensorflow.Graphs; using Tensorflow.NumPy; +using Tensorflow.Operations; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -22,11 +24,11 @@ namespace Tensorflow.Functions protected string _INFERENCE_PREFIX = "__inference_"; protected FuncGraph _func_graph; - protected EagerDefinedFunction _forward; + protected EagerDefinedFunction _forward_function; protected FuncGraph _forward_graph; protected List _forwardprop_output_indices; protected int _num_forwardprop_outputs; - protected ConcreteFunction _backward; + protected ConcreteFunction _backward_function; BackwardFunction _backward_function_wrapper; public TapeGradientFunctions(FuncGraph func_graph, @@ -49,8 +51,8 @@ namespace Tensorflow.Functions public virtual void Record(Tensors flat_outputs, Tensors inference_args) { // TODO(Rinne): add arg `input_tagents`. - var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); - tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, + var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs); + tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record, getBackwardFunction: backward_function); } @@ -134,46 +136,58 @@ namespace Tensorflow.Functions var trainable_indices = new List(); foreach(var (index, output) in enumerate(outputs)) { - if (gradients_util.IsTrainable(output)) + if (backprop_util.IsTrainable(output)) { trainable_outputs.Add(output); trainable_indices.Add(index); } } - var gradients_wrt_outputs = new List(); - var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); + var backwards_graph = new FuncGraph(_func_graph.Name); backwards_graph.as_default(); + var gradients_wrt_outputs = new List(); foreach (var output in trainable_outputs) - gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); + { + var (gradient_shape, gradient_dtype) = default_gradient.shape_and_dtype(output); + var gradient_placeholder = tf.placeholder(gradient_dtype, gradient_shape); + gradients_wrt_outputs.Add(gradient_placeholder); + handle_data_util.copy_handle_data(output, gradient_placeholder); + } var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), - _func_graph.Inputs, - grad_ys: gradients_wrt_outputs.ToArray(), - src_graph: _func_graph); + _func_graph.Inputs, + grad_ys: gradients_wrt_outputs.ToArray(), + src_graph: _func_graph); var captures_from_forward = backwards_graph.external_captures .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) .ToArray(); + HashSet existing_outputs = new(_func_graph.Outputs); foreach(var capture in captures_from_forward) { - if (!_func_graph.Outputs.Contains(capture)) + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); _func_graph.Outputs.Add(capture); + } } backwards_graph.Exit(); - var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; - var backward_function_attr = new Dictionary(); - backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; - gradients_wrt_outputs.append(backwards_graph.internal_captures); - backwards_graph.Inputs = gradients_wrt_outputs; - backwards_graph.Outputs = gradients_wrt_inputs; + backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); + backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); + + var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); + //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; + //var backward_function_attr = new Dictionary(); + //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; - var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + //var backward_function = new ConcreteFunction(backwards_graph, + // monomorphic_function_utils._parse_func_attrs(backward_function_attr)); - var forward_function_attr = new Dictionary(); - forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; - var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, - _func_graph.Inputs, _func_graph.Outputs, forward_function_attr); + //var forward_function_attr = new Dictionary(); + //forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; + //var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, + // _func_graph.Inputs, _func_graph.Outputs, + // monomorphic_function_utils._parse_func_attrs(forward_function_attr)); return (forward_function, _func_graph, backward_function, null, 0); } diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs new file mode 100644 index 00000000..8a844671 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class TracingCompiler + { + Func _csharp_function; + //FunctionSpec _function_spec; + internal string _name; + bool _autograph; + Dictionary _function_cache; + Dictionary _function_attributes; + int _tracing_count; + + public TracingCompiler(Func csharp_function, string name, object? input_signatures = null, + Dictionary attributes = null, bool autograph = true, object? autograph_options = null, + bool reduce_retracing = false, bool capture_by_value = false) + { + _csharp_function = csharp_function; + bool pure_function = attributes is not null && attributes.Count > 0 && attributes.ContainsKey(monomorphic_function_utils.IMPLEMENTS_ATTRIBUTE_NAME); + _name = name; + _autograph = autograph; + _function_attributes = attributes ?? new Dictionary(); + _function_cache = new Dictionary(); + _tracing_count = 0; + } + + public Tensor[] Apply(Tensor[] inputs) + { + // TODO(Rinne): add lock here. + var (concrete_function, filtered_flat_args) = _maybe_define_function(inputs); + return concrete_function.CallFlat(filtered_flat_args, concrete_function.CapturedInputs); + } + + internal ConcreteFunction _get_concrete_function_internal_garbage_collected(Tensor[] args) + { + var (concrete_function, _) = _maybe_define_concrete_function(args); + return concrete_function; + } + + private (ConcreteFunction, Tensor[]) _maybe_define_concrete_function(Tensor[] args) + { + return _maybe_define_function(args); + } + + private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) + { + var lookup_func_key = male_cache_key(args); + if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) + { + return (concrete_function, args); + } + concrete_function = _create_concrete_function(args); + _function_cache[lookup_func_key] = concrete_function; + return (concrete_function, args); + } + + private ConcreteFunction _create_concrete_function(Tensor[] args) + { + _tracing_count++; + + int arglen = args.Length; + var concrete_function = new ConcreteFunction(FuncGraph.func_graph_from_func( + _name, x => _csharp_function(x.Where(y => y is Tensor).Select(y => (Tensor)y).ToArray()), + args, new Dictionary(), autograph: _autograph + ), _function_attributes); + return concrete_function; + } + + private static string male_cache_key(Tensor[] inputs) + { + string res = ""; + foreach (var input in inputs) + { + res += $"{input.name}_{input.Id}"; + } + return res; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index 3fbb3868..04d102b5 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -16,6 +16,7 @@ using System; using System.Runtime.InteropServices; +using Tensorflow.Functions; namespace Tensorflow { @@ -54,6 +55,9 @@ namespace Tensorflow public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); [DllImport(TensorFlowLibName)] - public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); + public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, SafeFuncGraphHandle grad, SafeStatusHandle status); + + [DllImport(TensorFlowLibName)] + public static extern int TF_GraphGetFunctions(SafeGraphHandle g, IntPtr[] funcs, int max_func, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs new file mode 100644 index 00000000..7994bef1 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/composite_tensor_utils.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Util; + +namespace Tensorflow.Functions +{ + internal static class composite_tensor_utils + { + public static List flatten_with_variables(object inputs) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(inputs)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + public static List flatten_with_variables_or_variable_specs(object arg) + { + List flat_inputs = new(); + foreach(var value in nest.flatten(arg)) + { + if(value is CompositeTensor && !resource_variable_ops.is_resource_variable(value)) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + // TODO(Rinne): deal with `VariableSpec`. + else if(value is TypeSpec type_spec && value is not TensorSpec) + { + throw new NotImplementedException("The TypeSpec has not been fully supported."); + } + else + { + flat_inputs.Add(value); + } + } + return flat_inputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs index e92fa3a1..b3caef96 100644 --- a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs +++ b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs @@ -34,11 +34,10 @@ namespace Tensorflow.Functions "https://github.com/SciSharp/TensorFlow.NET/issues"); } }); - var bound_variables = inputs.TakeWhile(obj => obj is IVariableV1); + var bound_variables = inputs.Where(obj => obj is IVariableV1).Select(x => (IVariableV1)x); List captured_inputs_list = new(); - // TODO(Rinne): concrete_function.set_variables(bound_variables) - + concrete_function.set_variables(bound_variables); if (bound_inputs is not null) { @@ -54,8 +53,14 @@ namespace Tensorflow.Functions concrete_function.func_graph.replace_capture(bound_input, internal_capture); if(internal_capture.dtype == dtypes.resource) { - // skip the check of variable. - handle_data_util.copy_handle_data(bound_input, internal_capture); + if (resource_variable_ops.is_resource_variable(bound_input)) + { + handle_data_util.copy_handle_data(bound_input.Handle, internal_capture); + } + else + { + handle_data_util.copy_handle_data(bound_input, internal_capture); + } } concrete_function.func_graph.capture(bound_input); } diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs index a8769438..acf00597 100644 --- a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -1,20 +1,137 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; +using Tensorflow.Framework.Models; +using Tensorflow.Gradients; using Tensorflow.Graphs; +using Tensorflow.Common.Extensions; +using Tensorflow.Operations; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using System.Diagnostics; namespace Tensorflow.Functions { - public class DelayedRewriteGradientFunctions: TapeGradientFunctions + internal static class monomorphic_function_utils + { + internal static string _FORWARD_PREFIX = "__forward_"; + internal static string _BACKWARD_PREFIX = "__backward_"; + internal static string _INFERENCE_PREFIX = "__inference_"; + internal static string IMPLEMENTS_ATTRIBUTE_NAME = "_implements"; + internal static string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + internal static string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + public static string _inference_name(string name) + { + return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + } + public static string _forward_name(string name) + { + return $"{_FORWARD_PREFIX}{name}_{ops.uid()}"; + } + public static string _backward_name(string name) + { + return $"{_BACKWARD_PREFIX}{name}_{ops.uid()}"; + } + + public static (EagerDefinedFunction, ConcreteFunction) _create_forward_backward_with_graph(Dictionary attrs, + FuncGraph forward_graph, FuncGraph backwards_graph) + { + string forward_function_name = _forward_name(forward_graph.Name); + Dictionary common_attributes; + if(attrs is null) + { + common_attributes = new Dictionary(); + } + else + { + common_attributes = new Dictionary(attrs); + } + + if (common_attributes.ContainsKey(IMPLEMENTS_ATTRIBUTE_NAME)) + { + common_attributes.Remove(IMPLEMENTS_ATTRIBUTE_NAME); + } + var backward_function_attr = _parse_func_attrs(new Dictionary() + { + {FORWARD_FUNCTION_ATTRIBUTE_NAME, forward_function_name } + }); + backward_function_attr.Update(common_attributes); + var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); + var forward_function_attr = _parse_func_attrs(new Dictionary() + { + {BACKWARD_FUNCTION_ATTRIBUTE_NAME, backward_function.Name } + }); + forward_function_attr.Update(common_attributes); + var forward_function = new EagerDefinedFunction(forward_function_name, forward_graph, + forward_graph.Inputs, forward_graph.Outputs, forward_function_attr); + return (forward_function, backward_function); + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach(var item in attributes) + { + var key = item.Key; + var value = item.Value; + if (value is AttrValue attr_value) + { + attrs[key] = attr_value; + } + else if (value is bool b) + { + attrs[key] = new AttrValue() { B = b }; + } + else if (value is int i) + { + attrs[key] = new AttrValue() { I = i }; + } + else if (value is float f) + { + attrs[key] = new AttrValue() { F = f }; + } + else if(value is string s) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(s) }; + } + else if (value is byte[] bytes) + { + attrs[key] = new AttrValue() { S = ByteString.CopyFrom(bytes) }; + } + else + { + throw new ValueError($"Attribute {key} must be bool, int, float, string, or " + + $"AttrValue. Got {value.GetType()}."); + } + } + return attrs; + } + + public static Dictionary _parse_func_attrs(Dictionary attributes) + { + Dictionary attrs = new(); + foreach (var item in attributes) + { + var key = item.Key; + var value = item.Value; + attrs[key] = new AttrValue() { S = ByteString.CopyFromUtf8(value) }; + } + return attrs; + } + } + public class DelayedRewriteGradientFunctions : TapeGradientFunctions { EagerDefinedFunction _inference_function; - Dictionary _attrs; + Dictionary _attrs; int _num_inference_outputs; - public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) - :base(func_graph, false) + Dictionary _cached_function_pairs = new(); + public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) + : base(func_graph, false) { - _func_graph= func_graph; - _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), + _func_graph = func_graph; + _inference_function = new EagerDefinedFunction(monomorphic_function_utils._inference_name(_func_graph.Name), _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); _attrs = attrs; _num_inference_outputs = _func_graph.Outputs.Length; @@ -22,7 +139,7 @@ namespace Tensorflow.Functions public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) { - if(input_tangents is not null) + if (input_tangents is not null) { throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + $"a class that does not support forwardprop."); @@ -32,23 +149,134 @@ namespace Tensorflow.Functions public override void Record(Tensors flat_outputs, Tensors inference_args) { - // TODO(Rinne): implement it. - throw new NotImplementedException(); - base.Record(flat_outputs, inference_args); + var (backward_function, to_record) = _backward(flat_outputs); + foreach(var tape in tf.GetTapeSet()) + { + tape.RecordOperation(_inference_function.Signature.Name, to_record, + inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function); + } } - //private (BackwardFunction, Tensors) _backward(Tensors outputs) - //{ - // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) - // { - // var call_op = outputs[0].op; + public (EagerDefinedFunction, ConcreteFunction) forward_backward(int num_doutputs = -2) + { + if(num_doutputs == -2) + { + num_doutputs = _num_inference_outputs; + } + if(_cached_function_pairs.TryGetValue(num_doutputs, out var target)) + { + return target; + } + var (forward, backward) = _construct_forward_backward(num_doutputs); + _cached_function_pairs[num_doutputs] = (forward, backward); + return (forward, backward); - // } - //} + } - private string _inference_name(string name) + private (BackwardFunction, Tensors) _backward(Tensors outputs) { - return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + Tensor[] backward_function(Tensor[] args, long[] unneeded_gradients) + { + var call_op = outputs[0].op; + return _rewrite_forward_and_call_backward(call_op, args); + } + return (backward_function, outputs); + } + + internal Tensor[] _rewrite_forward_and_call_backward(Operation op, params object[] doutputs) + { + var (forward_function, backward_function) = forward_backward(doutputs.Length); + if(backward_function.Outputs is null || backward_function.Outputs.Length == 0) + { + return backward_function.FlatStructuredOutputs; + } + forward_function.AddToGraph(op.graph); + + op._set_func_attr("f", forward_function.Name); + op._set_type_list_attr("Tout", forward_function.OutputTypes); + op._add_outputs(forward_function.OutputTypes.Select(x => x.as_tf_dtype()). + Skip(op.outputs.Length).ToArray(), forward_function.OutputShapes.Skip(op.outputs.Length).ToArray() + ); + for(int i = 0; i < op.outputs.Length; i++) + { + var func_graph_output = forward_function._func_graph_outputs[i]; + handle_data_util.copy_handle_data(func_graph_output, op.outputs[i]); + } + + var capture_mapping = zip(_func_graph.Outputs.Select(t => ops.tensor_id(t)), op.outputs). + ToDictionary(x => x.Item1, x => x.Item2); + var remapped_captures = backward_function.CapturedInputs.Select( + x => capture_mapping.GetOrDefault(ops.tensor_id(x), x) + ); + + List cleaned_doutputs = new(); + foreach(var (doutput, placeholder) in zip(doutputs, _func_graph.Outputs)) + { + if (backprop_util.IsTrainable(placeholder)) + { + if(doutput is IndexedSlices) + { + cleaned_doutputs.Add(ops.convert_to_tensor(doutput)); + } + else if(doutput is null) + { + cleaned_doutputs.Add(default_gradient.zeros_like(placeholder)); + } + else if(doutput is Tensor tensor) + { + cleaned_doutputs.Add(tensor); + } + else + { + throw new ValueError($"Unsupported type {doutput.GetType()} in function _rewrite_forward_and_call_backward"); + } + } + } + + return backward_function.CallFlat(cleaned_doutputs.ToArray(), remapped_captures.ToArray()); + } + + private (EagerDefinedFunction, ConcreteFunction) _construct_forward_backward(int num_doutputs) + { + var trainable_outputs = _func_graph.Outputs.Take(num_doutputs).Where(x => backprop_util.IsTrainable(x)); + + List signature = new(); + foreach(var t in trainable_outputs) + { + var (shape, dtype) = default_gradient.shape_and_dtype(t); + signature.Add(new TensorSpec(shape, dtype)); + } + + Tensor[] _backprop_function(Tensor[] grad_ys) + { + return gradients_util._GradientsHelper(trainable_outputs.ToArray(), _func_graph.Inputs, + grad_ys, src_graph: _func_graph); + } + + _func_graph.as_default(); + FuncGraph backwards_graph = new(monomorphic_function_utils._backward_name(_func_graph.Name)); + FuncGraph.func_graph_from_func(backwards_graph.Name, x => _backprop_function(x.Select(y => + { + Debug.Assert(y is Tensor); + return (Tensor)y; + }).ToArray()), new object[0], new Dictionary(), signature.ToArray(), backwards_graph); + var backwards_graph_captures = backwards_graph.external_captures; + var captures_from_forward = backwards_graph_captures.Where(c => c is not EagerTensor && c.graph == _func_graph); + + HashSet existing_outputs = new HashSet(_func_graph.Outputs); + foreach(var capture in captures_from_forward) + { + if (!existing_outputs.Contains(capture)) + { + existing_outputs.Add(capture); + _func_graph.Outputs.Add(capture); + } + } + + var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph( + _attrs, _func_graph, backwards_graph); + _func_graph.Exit(); + return (forward_function, backward_function); } } } diff --git a/src/TensorFlowNET.Core/Gradients/default_gradient.cs b/src/TensorFlowNET.Core/Gradients/default_gradient.cs new file mode 100644 index 00000000..e6c22e36 --- /dev/null +++ b/src/TensorFlowNET.Core/Gradients/default_gradient.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Gradients +{ + internal static class default_gradient + { + public static (Shape, TF_DataType) shape_and_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return (new Shape(handle_data.ShapeAndType[0].Shape), handle_data.ShapeAndType[0].Dtype.as_tf_dtype()); + } + return (t.shape, t.dtype); + } + + public static Tensor zeros_like(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var (shape, dtype) = shape_and_dtype(t); + return array_ops.zeros(shape, dtype); + } + else + { + return array_ops.zeros_like(t); + } + } + + public static TF_DataType get_zeros_dtype(Tensor t) + { + if(t.dtype == dtypes.resource) + { + var handle_data = resource_variable_ops.get_eager_safe_handle_data(t); + if(handle_data is null || !handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new ValueError($"Internal error: Tried to take gradients (or similar) " + + $"of a variable without handle data:\n{t}"); + } + return handle_data.ShapeAndType[0].Dtype.as_tf_dtype(); + } + return t.dtype; + } + } +} diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index e6312c0d..10166911 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -14,10 +14,15 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Gradients; using Tensorflow.Graphs; +using Tensorflow.Operations; using Tensorflow.Operations.ControlFlows; using static Tensorflow.Binding; @@ -148,7 +153,7 @@ namespace Tensorflow Tensor[] in_grads = null; Func grad_fn = null; var is_partitioned_call = _IsPartitionedCall(op); - var is_func_call = false; + var is_func_call = src_graph.IsFunction(op.type) || is_partitioned_call; var has_out_grads = out_grads.Exists(x => x != null); if (has_out_grads && !stop_ops.Contains(op)) { @@ -162,14 +167,41 @@ namespace Tensorflow { if (is_func_call) { + EagerDefinedFunction func_call = null; if (is_partitioned_call) { - + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + var func_name = ((NameAttrList)func_attr).Name; + func_call = src_graph._get_function(func_name); + if(func_call is null && src_graph.OuterGraph is not null) + { + var graph = src_graph.OuterGraph; + while(graph is not null) + { + func_call = graph._get_function(func_name); + if(func_call is not null) + { + break; + } + if(graph.OuterGraph is not null) + { + graph = graph.OuterGraph; + } + else + { + break; + } + } + } } else { - + func_call = src_graph._get_function(op.type); } + // skip the following codes: + // `func_call = getattr(op, "__defun", func_call)` + grad_fn = func_call.csharp_grad_func; } else { @@ -213,6 +245,8 @@ namespace Tensorflow } else { + in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), + null, (x, y) => _SymGrad(x, y)); throw new NotImplementedException("lambda: _SymGrad(op, out_grads)"); } _VerifyGeneratedGradients(in_grads, op); @@ -668,6 +702,36 @@ namespace Tensorflow dtypes.resource, dtypes.variant}.Contains(dtype); } + public static int PossibleTapeGradientTypes(Tensor[] tensors) + { + var tape_set = tf.GetTapeSet(); + bool some_tape_watching = false; + if(tape_set is not null && tape_set.Count > 0) + { + foreach(var tape in tape_set) + { + if (tape.ShouldRecord(tensors)) + { + if(tape.Persistent || some_tape_watching) + { + return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; + } + some_tape_watching = true; + } + } + } + // skip the forward_accumulators. + + if (some_tape_watching) + { + return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; + } + else + { + return POSSIBLE_GRADIENT_TYPES_NONE; + } + } + /// /// Return true if op has real gradient. /// @@ -688,7 +752,7 @@ namespace Tensorflow private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) { - scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope; + //scope = scope.TrimEnd('/').Replace('/', '_'); return grad_fn(op, out_grads); } @@ -701,5 +765,28 @@ namespace Tensorflow throw new ValueError($"Num gradients {grads.Length} generated for op {op.node_def} do not match num " + $"inputs {op.inputs._inputs.Count()}"); } + + private static Tensor[] _SymGrad(Operation op, Tensor[] out_grads) + { + var f_in = ((Tensor[])op.inputs).Concat(out_grads).ToArray(); + var f_types = ((Tensor[])op.inputs).Select(x => default_gradient.get_zeros_dtype(x)).ToArray(); + NameAttrList f = new(); + if (_IsPartitionedCall(op)) + { + var func_attr = op.get_attr("f"); + Debug.Assert(func_attr is NameAttrList); + f.Name = ((NameAttrList)func_attr).Name; + } + else + { + f.Name = op.type; + } + foreach(var k in op.node_def.Attr.Keys) + { + f.Attr[k] = AttrValue.Parser.ParseFrom(op.node_def.Attr[k].ToByteArray()); + } + var in_grads = gen_functional_ops.symbolic_gradient(f_in, f_types, f); + return in_grads; + } } } diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index e5831f25..7d3ea171 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -98,12 +98,23 @@ namespace Tensorflow { if (op.inputs == null) return null; - RegisterFromAssembly(); + var gradient_function = op._gradient_function; + if(gradient_function is null) + { + RegisterFromAssembly(); + + if (!gradientFunctions.ContainsKey(op.type)) + throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); - if (!gradientFunctions.ContainsKey(op.type)) - throw new LookupError($"can't get graident function through get_gradient_function {op.type}"); + return gradientFunctions[op.type]; + } - return gradientFunctions[op.type]; + Tensor[] wrapped_gradient_function(Operation operation, Tensor[] args) + { + return gradient_function(operation, args); + } + // TODO(Rinne): check if this needs to be registered. + return wrapped_gradient_function; } } } diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 9367414e..9ef0b95b 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -1,6 +1,15 @@ using Google.Protobuf; +using System; +using System.Buffers; +using System.Diagnostics; +using System.Linq; using Tensorflow.Eager; using Tensorflow.Exceptions; +using Tensorflow.Framework; +using Tensorflow.Framework.Models; +using Tensorflow.Functions; +using Tensorflow.Operations; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Graphs; @@ -11,12 +20,65 @@ namespace Tensorflow.Graphs; public class FuncGraph : Graph, IDisposable { internal SafeFuncGraphHandle _func_graph_handle; + internal HashSet _resource_tensor_inputs; + internal HashSet> _watched_variables; + internal IEnumerable> _weak_variables; + internal object[] _structured_outputs; + internal Dictionary _output_names; public string FuncName => _graph_key; public Tensors Inputs { get; set; } = new Tensors(); public Tensors Outputs { get; set; } = new Tensors(); + public Tensors FlatStructuredOutputs + { + get + { + List res = new(); + foreach(var obj in _structured_outputs) + { + if(obj is Tensor tensor) + { + res.Add(tensor); + } + else if(obj is IEnumerable tensors) + { + res.AddRange(tensors); + } + else + { + throw new TypeError("The structured outputs member should be tensor or tensors."); + } + } + return res; + } + } public string Name { get; set; } - public Dictionary Attrs { get; set; } + public IEnumerable Variables + { + get + { + return _weak_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + throw new AssertionError("Called a function referencing variables which have been deleted. " + + "This likely means that function-local variables were created and " + + "not referenced elsewhere in the program. This is generally a " + + "mistake; consider storing variables in an object attribute on first call."); + } + }); + } + internal set + { + _weak_variables = value.Select(x => new WeakReference(x)); + } + } + public IEnumerable TrainableVariables => Variables.Where(v => v.Trainable); + public Dictionary Attrs { get; set; } Dictionary _captures = new Dictionary(); @@ -42,9 +104,12 @@ public class FuncGraph : Graph, IDisposable outer_graph = outer_graph.OuterGraph; _graph_key = Name = name; building_function = true; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); } - public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base() + public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base() { outer_graph = ops.get_default_graph(); while (outer_graph.building_function) @@ -55,6 +120,9 @@ public class FuncGraph : Graph, IDisposable // Will to test if FuncGraph has memory leak // c_api.TF_DeleteGraph(_handle); _handle = handle; + _weak_variables = new List>(); + _resource_tensor_inputs = new HashSet(); + _watched_variables = new HashSet>(); } public void replace_capture(Tensor tensor, Tensor placeholder) @@ -62,14 +130,14 @@ public class FuncGraph : Graph, IDisposable _captures[tensor.Id] = (tensor, placeholder); } - public void ToGraph(Operation[] opers, + public unsafe void ToGraph(Operation[] opers, Tensor[] inputs, Tensor[] outputs, string[] output_names) { var status = new Status(); - if (output_names != null && output_names.Length == 0) + if (output_names is null) { - output_names = null; + output_names = new string[0]; }; _func_graph_handle = c_api.TF_GraphToFunction(_handle, @@ -81,7 +149,7 @@ public class FuncGraph : Graph, IDisposable inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), outputs.Length, outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), - output_names, + output_names.Length != outputs.Length ? null : output_names, IntPtr.Zero, null, status); @@ -211,6 +279,19 @@ public class FuncGraph : Graph, IDisposable Inputs.Add(placeholder); } + Tensor pop_capture(Tensor tensor) + { + if(_captures.TryGetValue(tensor.Id, out var capture)) + { + _captures.Remove(tensor.Id); + return capture.Item2; + } + else + { + return null; + } + } + Tensor _create_substitute_placeholder(Tensor value, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -234,10 +315,7 @@ public class FuncGraph : Graph, IDisposable foreach (var (_name, attr_value) in enumerate(Attrs)) { - var serialized = new AttrValue - { - S = ByteString.CopyFromUtf8(attr_value) - }.ToByteArray(); + var serialized = attr_value.ToByteArray(); c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); tf.Status.Check(true); } @@ -260,4 +338,261 @@ public class FuncGraph : Graph, IDisposable { c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); } + + public static FuncGraph func_graph_from_func(string name, Func func, + object[] args, Dictionary kwargs, TensorSpec[] signature = null, + FuncGraph func_graph = null, bool autograph = false, object autograph_options = null, + bool add_control_dependencies = true, string[] arg_names = null, + Tensor op_return_value = null, bool capture_by_value = false, + bool acd_record_initial_resource_uses = false) + { + if(func_graph is null) + { + func_graph = new FuncGraph(name); + } + + // TODO(Rinne): deal with control dependencies. + + func_graph.as_default(); + var current_scope = variable_scope.get_variable_scope(); + var default_use_resource = current_scope.use_resource; + current_scope.use_resource = true; + + if(signature is not null) + { + args = signature; + kwargs = new Dictionary(); + } + var func_args = _get_defun_inputs_from_args(args, arg_names); + var func_kwargs = _get_defun_inputs_from_kwargs(kwargs); + + if(func_kwargs is not null && func_kwargs.Count > 0) + { + throw new NotImplementedException("The keyword args has not been supported in `func_graph_from_func`."); + } + + foreach(var arg in nest.flatten(new object[] { func_args, func_kwargs })) + { + if(arg is Tensor tensor && tensor.dtype == dtypes.resource) + { + func_graph._resource_tensor_inputs.Add(tensor); + } + else if (arg is ResourceVariable variable) + { + func_graph._resource_tensor_inputs.Add(variable.Handle); + } + } + + // skip the assignment of `func_graph.structured_input_signature`. + + var flat_func_args = nest.flatten(func_args as object); + var flat_func_kwargs = nest.flatten(func_kwargs as object); + func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) + .Where(x => x is Tensor).Select(x => (Tensor)x)); + + //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); + //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); + + Tensor convert(object x) + { + if (x is null) return null; + Tensor res = null; + if(op_return_value is not null && x is Operation) + { + tf_with(ops.control_dependencies(new object[] { x }), _ => + { + res = array_ops.identity(op_return_value); + }); + } + else if(x is not TensorArray) + { + Debug.Assert(x is Tensor); + res = ops.convert_to_tensor_or_composite(x as Tensor); + } + else + { + throw new NotImplementedException($"The `TensorArray` is not supported here currently."); + } + if (add_control_dependencies) + { + // TODO(Rinne): `x = deps_ctx.mark_as_return(x)`. + } + return res; + } + + if (autograph) + { + throw new NotImplementedException("The autograph of `func_graph_from_func` has not been supported."); + } + + var func_outputs = func(func_args); + func_outputs = variable_utils.convert_variables_to_tensors(func_outputs); + func_outputs = func_outputs.Select(x => convert(x)).ToArray(); + // TODO(Rinne): `check_func_mutation`. + + current_scope.use_resource = default_use_resource; + + var graph_variables = func_graph._watched_variables.ToList(); + HashSet arg_variables = new HashSet(); + List inputs = new(); + foreach(var arg in composite_tensor_utils.flatten_with_variables(func_args)) + { + if(arg is BaseResourceVariable variable) + { + var resource_placeholder = func_graph.pop_capture(variable.Handle); + if(resource_placeholder is null) + { + continue; + } + Debug.Assert(variable is IVariableV1); + arg_variables.Add(variable as IVariableV1); + inputs.Add(resource_placeholder); + } + else if(arg is Tensor tensor) + { + inputs.Add(tensor); + } + } + var variables = graph_variables.Select(v => + { + if (v.TryGetTarget(out var target)) + { + return target; + } + else + { + return null; + } + }).Where(v => v is not null && !arg_variables.Contains(v)); + func_graph.Inputs = inputs.Concat(func_graph.internal_captures).ToArray(); + func_graph._structured_outputs = func_outputs; + func_graph.Outputs.AddRange(func_graph.FlatStructuredOutputs.Where(x => x is not null) + .Select(x => func_graph.capture(x))); + + func_graph.Variables = variables; + + func_graph.Exit(); + + if (add_control_dependencies) + { + // TODO(Rinne): implement it. + } + return func_graph; + } + + private static object[] _get_defun_inputs_from_args(object[] args, string[] names) + { + return _get_defun_inputs(args, names, args) as object[]; + } + + private static Dictionary _get_defun_inputs_from_kwargs(Dictionary kwargs) + { + // TODO(Rinne): implement it. + Debug.Assert(kwargs is null || kwargs.Count == 0); + return kwargs; + //string[] names; + //object[] args; + //if(kwargs is not null && kwargs.Count > 0) + //{ + // var sorted_kwargs = kwargs.OrderBy(x => x.Key); + // names = sorted_kwargs.Select(x => x.Key).ToArray(); + // args = sorted_kwargs.Select(x => x.Value).ToArray(); + //} + //else + //{ + // names = new string[0]; + // args = new object[0]; + //} + //return _get_defun_inputs(args, names, kwargs) as Dictionary; + } + + private static object _get_defun_inputs(object[] args, string[] names, object structured_args) + { + List function_inputs = new(); + if(names is null) + { + names = new string[args.Length]; + } + + foreach(var (arg_value, name) in zip(args, names)) + { + foreach(var val in composite_tensor_utils.flatten_with_variables_or_variable_specs(arg_value)) + { + function_inputs.Add(_get_defun_input(val, name)); + } + } + return nest.pack_sequence_as(structured_args, nest.flatten(function_inputs), true); + } + + private static object _get_defun_input(object arg, string name) + { + var func_graph = ops.get_default_graph() as FuncGraph; + Debug.Assert(func_graph is not null); + if (arg is Tensor tensor) + { + Tensor placeholder; + try + { + placeholder = tf.placeholder(tensor.dtype, tensor.shape, name); + } + catch (ValueError) + { + // TODO(Rinne): Add warning here. + placeholder = tf.placeholder(tensor.dtype, tensor.shape); + } + handle_data_util.copy_handle_data(tensor, placeholder); + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + } + return placeholder; + } + else if (arg is TensorSpec spec) + { + string requested_name; + if (!string.IsNullOrEmpty(spec.name)) + { + requested_name = spec.name; + } + else + { + requested_name = name; + } + Tensor placeholder; + try + { + placeholder = tf.placeholder(spec.dtype, spec.shape, requested_name); + } + catch (ValueError) + { + // TODO(Rinne): Add warning here. + placeholder = tf.placeholder(spec.dtype, spec.shape); + } + if (name is not null) + { + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(requested_name) + }); + } + return placeholder; + } + else if (arg is BaseResourceVariable variable) + { + var placeholder = func_graph.capture(variable.Handle, name); + placeholder.op._set_attr("_user_specified_name", new AttrValue() + { + S = tf.compat.as_bytes(name) + }); + return arg; + } + // TODO(Rinne): deal with `VariableSpec`. + else + { + return arg; + } + } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs index 91aef2dc..bed8b35c 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Gradient.cs.cs @@ -1,4 +1,6 @@ -namespace Tensorflow +using Tensorflow.Graphs; + +namespace Tensorflow { public partial class Graph { @@ -6,5 +8,10 @@ { } + + internal GraphOverrideGradientContext _override_gradient_function(Dictionary> gradient_function_map) + { + return new GraphOverrideGradientContext(this, gradient_function_map); + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index fc356687..c788aaf0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -118,7 +118,7 @@ namespace Tensorflow /// (Optional.) If True, device functions will be executed /// to compute the device property of the Operation. /// An `Operation` object. - public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true) + public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true, OperationDescription desc = null) { var ret = new Operation(c_op, this); _add_op(ret); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index e583868e..f443bcff 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -21,6 +21,7 @@ using System.Collections.Specialized; using System.Linq; using Tensorflow.Framework; using Tensorflow.Functions; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; namespace Tensorflow @@ -88,6 +89,7 @@ namespace Tensorflow private List _unfetchable_ops = new List(); private List _unfeedable_tensors = new List(); private Dictionary _functions = new(); + internal Dictionary> _gradient_function_map = new(); private VersionDef _graph_def_versions = new VersionDef() { Producer = versions.GRAPH_DEF_VERSION, @@ -161,13 +163,30 @@ namespace Tensorflow return _functions.ContainsKey(tf.compat.as_str(name)); } - public void AddFunction(EagerDefinedFunction function) + internal void AddFunction(EagerDefinedFunction function) { _check_not_finalized(); var name = function.Name; + if(function._grad_func_name is not null && function.csharp_grad_func is not null) + { + throw new ValueError($"Gradient defined twice for function {name}"); + } - // TODO(Rinne): deal with c_graph + var c_graph = this.c_graph; + var func = function._c_func.Get(); + Status status = new(); + if (function._grad_func is not null) + { + var gradient = function._grad_func._c_func.Get(); + c_api.TF_GraphCopyFunction(c_graph, func, gradient, status); + status.Check(true); + } + else + { + c_api.TF_GraphCopyFunction(c_graph, func, new SafeFuncGraphHandle(IntPtr.Zero), status); + status.Check(true); + } _functions[tf.compat.as_str(name)] = function; @@ -332,6 +351,9 @@ namespace Tensorflow private void _create_op_helper(Operation op, bool compute_device = true) { + // high priority + // TODO(Rinne): complete the implementation. + op._gradient_function = _gradient_function_map.GetOrDefault(op.type, null); _record_op_seen_by_control_dependencies(op); } @@ -548,6 +570,11 @@ namespace Tensorflow ops.pop_graph(); } + internal EagerDefinedFunction _get_function(string name) + { + return _functions.GetOrDefault(name, null); + } + string debugString = string.Empty; public override string ToString() { diff --git a/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs new file mode 100644 index 00000000..2befbbff --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphOverrideGradientContext.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +namespace Tensorflow.Graphs +{ + internal class GraphOverrideGradientContext: ITensorFlowObject + { + Graph _graph; + Dictionary> _new_gradient_function_map; + public GraphOverrideGradientContext(Graph graph, + Dictionary> new_gradient_function_map) + { + _graph = graph; + _new_gradient_function_map = new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __enter__() + { + Debug.Assert(_graph._gradient_function_map.Count == 0); + _graph._gradient_function_map = _new_gradient_function_map; + } + + [DebuggerStepThrough] + public void __exit__() + { + _graph._gradient_function_map = new Dictionary>(); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 28e69886..ca00710c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -20,6 +20,9 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Util; using static Tensorflow.Binding; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using System.Diagnostics; namespace Tensorflow { @@ -47,6 +50,8 @@ namespace Tensorflow private readonly Graph _graph; + internal Func _gradient_function; + public string type => OpType; public Graph graph => _graph; @@ -61,7 +66,7 @@ namespace Tensorflow public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - // OperationDescription _opDesc; + //private OperationDescription _op_desc; public NodeDef node_def => GetNodeDef(); @@ -216,21 +221,19 @@ namespace Tensorflow var x = AttrValue.Parser.ParseFrom(buf.ToArray()); - string oneof_value = x.ValueCase.ToString(); - if (string.IsNullOrEmpty(oneof_value)) - return null; + var oneof_value = x.ValueCase; + if (oneof_value == AttrValue.ValueOneofCase.None) + return new object[0]; - switch (oneof_value.ToLower()) + if(oneof_value == AttrValue.ValueOneofCase.List) { - case "list": - throw new NotImplementedException($"Unsupported field type in {oneof_value}"); - case "type": - return x.Type; - case "s": - return x.S.ToStringUtf8(); - default: - return x.GetType().GetProperty(oneof_value).GetValue(x); + throw new NotImplementedException($"Unsupported field type in {oneof_value}"); } + if(oneof_value == AttrValue.ValueOneofCase.Type) + { + return dtypes.as_tf_dtype(x.Type); + } + return ProtoUtils.GetSingleAttrValue(x, oneof_value); } public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) @@ -309,5 +312,83 @@ namespace Tensorflow } public NDArray numpy() => throw new NotImplementedException(""); + + internal void _add_outputs(TF_DataType[] types, Shape[] shapes) + { + Debug.Assert(types.Length == shapes.Length); + int orig_num_outputs = this.outputs.Length; + //var new_outputs = new List(_outputs); + + var old_outputs = _outputs; + _outputs = new Tensor[orig_num_outputs + types.Length]; + for(int i = 0; i < orig_num_outputs; i++) + { + _outputs[i] = old_outputs[i]; + } + + // Since the `_outputs` is defined as `Array`, when we add new output, we + // have to create a new array, which brings some performance concerns. + // In the future maybe the type of `outputs` should be reconsidered. + for(int i = 0; i < types.Length; i++) + { + var t = new Tensor(this, orig_num_outputs + 1, types[i]); + _outputs[i] = t; + //t = tf.ensure_shape(t, shapes[i]); + t.shape = shapes[i]; + //new_outputs.Add(t); + } + //_outputs = new_outputs.ToArray(); + } + + internal void _set_func_attr(string attr_name, string func_name) + { + var func = new NameAttrList() { Name = func_name }; + _set_attr(attr_name, new AttrValue() { Func = func }); + } + + internal void _set_type_list_attr(string attr_name, DataType[] types) + { + if(types is null || types.Length == 0) + { + return; + } + var type_list = new AttrValue.Types.ListValue(); + type_list.Type.AddRange(types); + _set_attr(attr_name, new AttrValue() { List = type_list }); + } + + internal void _set_attr(string attr_name, AttrValue attr_value) + { + var buffer = new Buffer(attr_value.ToByteArray()); + try + { + _set_attr_with_buf(attr_name, buffer); + } + finally + { + buffer.Release(); + } + } + + internal void _set_attr_with_buf(string attr_name, Buffer attr_buf) + { + //if(_op_desc is null) + //{ + // //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray()); + // //new_node_def.Name += "_temp"; + // //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types); + // //Status status = new(); + // //c_api.TF_SetAttrBool(op._op_desc, "trainable", true); + // ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); + // //status.Check(true); + // // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`. + //} + //else + //{ + // //Status status = new(); + // //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status); + // //status.Check(true); + //} + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 46a654e0..43dc8cd4 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -208,9 +208,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); - [DllImport(TensorFlowLibName)] - public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); - [DllImport(TensorFlowLibName)] - public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); + //[DllImport(TensorFlowLibName)] + //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + //[DllImport(TensorFlowLibName)] + //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); } } diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 2d447207..9c2e85d1 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -39,7 +39,7 @@ namespace Tensorflow if (config is null) { - config = function_utils.get_disabled_rewriter_config(); + config = function_utils.get_disabled_rewriter_config().ToString(); } if (executor_type is null) diff --git a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs index ce37ec7d..bb84ac39 100644 --- a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs @@ -79,5 +79,50 @@ namespace Tensorflow.Operations }; } + + public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + "SymbolicGradient", name, input, Tout, f)); + return _result; + } + catch (Exception) + { + + } + + try + { + return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); + } + catch (Exception) + { + + } + } + var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); + var result = op.outputs; + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return result; + } + + public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) + { + object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; + var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return result; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index bf178b60..8f8b2f11 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -10050,13 +10050,51 @@ namespace Tensorflow.Operations /// public static Tensor ensure_shape(Tensor input, Shape shape, string name = "EnsureShape") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); + return _result[0]; + } + catch (Exception) + { + Console.WriteLine(); + } + try + { + return ensure_shape_eager_fallback(input, shape, name, ctx); + } + catch (Exception) + { + Console.WriteLine(); + } + } + var dict = new Dictionary(); dict["input"] = input; dict["shape"] = shape; var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } return op.output; } + public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) + { + object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; + var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, + attrs, ctx, name); + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return _result[0]; + } + /// /// Creates or finds a child frame, and makes data available to the child frame. /// diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs index 5d5fbebb..66daa5c0 100644 --- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -52,5 +52,7 @@ namespace Tensorflow.Operations // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); } + + public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); } } diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 7921f28b..3e39338b 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -24,6 +24,7 @@ using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.Binding; using Tensorflow.Operations; using System.Buffers; +using Tensorflow.Eager; namespace Tensorflow { @@ -41,12 +42,7 @@ namespace Tensorflow name: name); } - public static bool is_resource_variable(IVariableV1 var) - { - return var is BaseResourceVariable; - } - - public static bool is_resource_variable(Trackable var) + public static bool is_resource_variable(object var) { return var is BaseResourceVariable; } @@ -138,10 +134,27 @@ namespace Tensorflow /// internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) { - tensor.HandleData = handle_data; if (!graph_mode) return; + var size = handle_data.ShapeAndType.Count; + + var shapes = new IntPtr[size]; + var types = new DataType[size]; + var ranks = new int[size]; + + for (int i = 0; i < size; i++) + { + var shapeAndType = handle_data.ShapeAndType[i]; + types[i] = shapeAndType.Dtype; + ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; + var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); + } + + //tensor.HandleData = handle_data; + //if (!graph_mode) + // return; + //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); @@ -196,24 +209,6 @@ namespace Tensorflow throw new NotImplementedException(""); } - private static HandleData get_eager_safe_handle_data(Tensor handle) - { - if (handle.Handle == null) - { - var data = new HandleData(); - data.ShapeAndType.Add(new HandleShapeAndType - { - Shape = handle.shape.as_shape_proto(), - Dtype = handle.dtype.as_datatype_enum() - }); - return data; - } - else - { - return HandleData.Parser.ParseFrom(handle.BufferToArray()); - } - } - /// /// Copies an existing variable to a new graph, with no initializer. /// @@ -281,5 +276,31 @@ namespace Tensorflow } } } + + public static HandleData get_eager_safe_handle_data(Tensor handle) + { + if (handle.Handle == null) + { + var data = new HandleData(); + data.ShapeAndType.Add(new HandleShapeAndType + { + Shape = handle.shape.as_shape_proto(), + Dtype = handle.dtype.as_datatype_enum() + }); + return data; + } + else + { + return HandleData.Parser.ParseFrom(handle.BufferToArray()); + } + //if(handle is EagerTensor) + //{ + // return handle.HandleData; + //} + //else + //{ + // return handle_data_util.get_resource_handle_data(handle); + //} + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 79b8d2c5..fff3cde5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -101,6 +101,7 @@ namespace Tensorflow _op = op; _value_index = value_index; _override_dtype = dtype; + _tf_output = null; _id = ops.uid(); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 0bffbfba..6ca65a07 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -136,9 +136,9 @@ namespace Tensorflow protected virtual void SetShapeInternal(Shape value) { if (value == null) - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status); else - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); + c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status); } public int[] _shape_tuple() @@ -177,7 +177,9 @@ namespace Tensorflow if (_handle == null) { var output = _as_tf_output(); - int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); + Status status = new(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(true); return ndim; } @@ -199,7 +201,7 @@ namespace Tensorflow public TF_Output _as_tf_output() { if (!_tf_output.HasValue) - _tf_output = new TF_Output(op, value_index); + _tf_output = new TF_Output(op, _value_index); return _tf_output.Value; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 60972775..3d734cd1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -56,7 +56,7 @@ namespace Tensorflow public void Add(Tensor tensor) => items.Add(tensor); - public void AddRange(Tensor[] tensors) + public void AddRange(IEnumerable tensors) => items.AddRange(tensors); public void Insert(int index, Tensor tensor) diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs index 341a12ab..695eadfd 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs @@ -12,11 +12,12 @@ namespace Tensorflow.Training.Saving.SavedModel { public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) { - this.forward_backward = concrete_function.forward_backward; - this.Outputs = concrete_function.Outputs; - this.ReturnType = concrete_function.ReturnType; - this.OutputStructure = concrete_function.OutputStructure; - this.ArgKeywords = concrete_function.ArgKeywords; + throw new NotImplementedException(); + //this.forward_backward = concrete_function.forward_backward; + //this.Outputs = concrete_function.Outputs; + //this.ReturnType = concrete_function.ReturnType; + //this.OutputStructure = concrete_function.OutputStructure; + //this.ArgKeywords = concrete_function.ArgKeywords; } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index 951d7d00..69dd2c10 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -30,6 +30,31 @@ namespace Tensorflow.Training.Saving.SavedModel { var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); + Tensor[] restored_function_body(Tensor[] inputs) + { + if(saved_function.ConcreteFunctions is null || saved_function.ConcreteFunctions.Count == 0) + { + throw new ValueError("Found zero restored functions for caller function."); + } + foreach(var function_name in saved_function.ConcreteFunctions) + { + var function = concrete_functions[function_name]; + if(function.CapturedInputs.Any(x => x is null)) + { + throw new ValueError("Looks like you are trying to run a loaded " + + "non-Keras model that was trained using tf.distribute.experimental.ParameterServerStrategy " + + "with variable partitioning, which is not currently supported. Try using Keras to define your model " + + "if possible."); + } + if(_concrete_function_callable_with(function, inputs, false)) + { + return _call_concrete_function(function, inputs); + } + } + throw new ValueError("Unexpected runtime behavior, please submit an issue to " + + "https://github.com/SciSharp/TensorFlow.NET/issues"); + } + List concrete_function_objects = new(); foreach(var concrete_function_name in saved_function.ConcreteFunctions) { @@ -40,17 +65,10 @@ namespace Tensorflow.Training.Saving.SavedModel cf._set_function_spec(function_spec); } - foreach(var function_name in saved_function.ConcreteFunctions) - { - var function = concrete_functions[function_name]; - if(_concrete_function_callable_with(function, null, false)) - { - return new RestoredFunction(null, function, "function_from_deserialization"); - } - } - return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); - //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + - // "https://github.com/SciSharp/TensorFlow.NET/issues"); + var restored_function = new RestoredFunction(restored_function_body, nameof(restored_function_body), + function_spec, concrete_function_objects); + + return restored_function; } public static Dictionary load_function_def_library(FunctionDefLibrary library, @@ -102,15 +120,17 @@ namespace Tensorflow.Training.Saving.SavedModel { var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); - if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) + object structured_input_signature = null; + object structured_outputs = null; + if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) { - // TODO(Rinne): implement it. - //var proto = saved_object_graph.ConcreteFunctions[orig_name]; - //throw new NotImplementedException(); + var proto = saved_object_graph.ConcreteFunctions[orig_name]; + structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); + structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); } graph.as_default(); - var func_graph = function_def_lib.function_def_to_graph(fdef, null, null); + var func_graph = function_def_lib.function_def_to_graph(fdef, structured_input_signature, structured_outputs); graph.Exit(); _restore_gradient_functions(func_graph, renamed_functions, loaded_gradients); @@ -124,7 +144,7 @@ namespace Tensorflow.Training.Saving.SavedModel { fdef.Attr.Remove("_input_shapes"); } - var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value.S.ToString())); + var func = new ConcreteFunction(func_graph, fdef.Attr.ToDictionary(x => x.Key, x => x.Value)); if(wrapper_function is not null) { throw new NotImplementedException(); @@ -142,8 +162,7 @@ namespace Tensorflow.Training.Saving.SavedModel { var gradient_op_type = gradients_to_register[orig_name]; loaded_gradients[gradient_op_type] = func; - // TODO(Rinne): deal with gradient registry. - //new RegisteredGradient() { RegisteredOpType = gradient_op_type }. + ops.RegisterGradientFunction(gradient_op_type, _gen_gradient_func(func)); } } return functions; @@ -203,6 +222,16 @@ namespace Tensorflow.Training.Saving.SavedModel } } + private static Func _gen_gradient_func(ConcreteFunction func) + { + return (unused_op, result_grads) => + { + result_grads = zip(result_grads, func.func_graph.Inputs) + .Select((item) => item.Item1 is null ? default_gradient.zeros_like(item.Item2) : item.Item1).ToArray(); + return func.CallFlat(result_grads, func.CapturedInputs); + }; + } + private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary renamed_functions, Dictionary loaded_gradients) { foreach(var op in func_graph.get_operations()) @@ -210,14 +239,14 @@ namespace Tensorflow.Training.Saving.SavedModel if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") { var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; - // TODO(Rinne): deal with `op._gradient_function`. + op.op._gradient_function = function._get_gradient_function(); } string gradient_op_type = null; try { gradient_op_type = op.op.get_attr("_gradient_op_type") as string; } - catch(Exception e) + catch(InvalidArgumentError) { continue; } @@ -389,7 +418,7 @@ namespace Tensorflow.Training.Saving.SavedModel concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; - var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + //var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); // TODO(Rinne): set the functiona spec. concrete_function.AddTograph(); return concrete_function; @@ -413,19 +442,31 @@ namespace Tensorflow.Training.Saving.SavedModel return function.CallFlat(inputs, function.CapturedInputs); } - private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) + private static bool _concrete_function_callable_with(ConcreteFunction function, Tensor[] inputs, bool allow_conversion) { // TODO(Rinne): revise it. - return true; + return function.CapturedInputs.Length + inputs.Length == function.Inputs.Length; + //var expected_inputs = function.func_graph.Inputs; + //foreach(var (arg, expected) in zip(inputs, expected_inputs)) + //{ + // if(arg.Id != expected.Id) + // { + // return false; + // } + //} + //return true; } } public class RestoredFunction : Function { - public RestoredFunction(Func function, ConcreteFunction concrete_function, - string name, bool auto_graph = true): base(function, name, auto_graph) + IEnumerable _concrete_functions; + FunctionSpec _function_spec; + public RestoredFunction(Func function, string name, FunctionSpec function_spec, + IEnumerable concrete_functions): base(function, name, auto_graph: false) { - _concrete_variable_creation_fn = concrete_function; + _concrete_functions = concrete_functions; + _function_spec = function_spec; } protected override bool _run_functions_eagerly() diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 4a0d3b00..d3ffebc9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -102,6 +102,6 @@ public class SignatureMap: Trackable return new Dictionary(); } - return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); + return _signatures.Where(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } } diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index a8033f59..6b607e85 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -132,8 +132,8 @@ namespace Tensorflow.Training { get { - var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); - var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); + var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList(); + var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList(); List non_trainable_variables = new(); foreach(var obj in Values) { @@ -576,7 +576,7 @@ namespace Tensorflow.Training if(save_type == SaveType.SAVEDMODEL) { - children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); + children = children.Concat(this.Where(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); } return children; diff --git a/src/TensorFlowNET.Core/Util/ProtoUtils.cs b/src/TensorFlowNET.Core/Util/ProtoUtils.cs new file mode 100644 index 00000000..e7de8e30 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/ProtoUtils.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + internal static class ProtoUtils + { + public static object GetSingleAttrValue(AttrValue value, AttrValue.ValueOneofCase valueCase) + { + return valueCase switch + { + AttrValue.ValueOneofCase.S => value.S, + AttrValue.ValueOneofCase.I => value.I, + AttrValue.ValueOneofCase.F => value.F, + AttrValue.ValueOneofCase.B => value.B, + AttrValue.ValueOneofCase.Type => value.Type, + AttrValue.ValueOneofCase.Shape => value.Shape, + AttrValue.ValueOneofCase.Tensor => value.Tensor, + AttrValue.ValueOneofCase.Func => value.Func, + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Util/function_utils.cs b/src/TensorFlowNET.Core/Util/function_utils.cs index 2944e88e..d4ba4423 100644 --- a/src/TensorFlowNET.Core/Util/function_utils.cs +++ b/src/TensorFlowNET.Core/Util/function_utils.cs @@ -7,15 +7,15 @@ namespace Tensorflow.Util { internal static class function_utils { - private static string _rewriter_config_optimizer_disabled; - public static string get_disabled_rewriter_config() + private static ByteString _rewriter_config_optimizer_disabled; + public static ByteString get_disabled_rewriter_config() { if(_rewriter_config_optimizer_disabled is null) { var config = new ConfigProto(); var rewriter_config = config.GraphOptions.RewriteOptions; rewriter_config.DisableMetaOptimizer = true; - _rewriter_config_optimizer_disabled = config.ToString(); + _rewriter_config_optimizer_disabled = config.ToByteString(); } return _rewriter_config_optimizer_disabled; } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index c4537896..eb94f4d0 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -137,10 +137,12 @@ namespace Tensorflow.Util switch (instance) { case Hashtable hash: - var result = new Hashtable(); - foreach ((object key, object value) in zip(_sorted(hash), args)) - result[key] = value; - return result; + { + var result = new Hashtable(); + foreach ((object key, object value) in zip(_sorted(hash), args)) + result[key] = value; + return result; + } } } //else if( _is_namedtuple(instance) || _is_attrs(instance)) @@ -221,6 +223,16 @@ namespace Tensorflow.Util return list; } + public static List flatten(IEnumerable structure) + { + var list = new List(); + foreach(var item in structure) + { + _flatten_recursive(item, list); + } + return list; + } + public static object[] flatten2(ICanBeFlattened structure) => structure.Flatten(); @@ -527,6 +539,14 @@ namespace Tensorflow.Util return pack_sequence_as(structure, mapped_flat_structure) as T2; } + public static IEnumerable map_structure(Func func, IEnumerable structure) where T2 : class + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); + + return pack_sequence_as(structure, mapped_flat_structure) as IEnumerable; + } + /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// diff --git a/src/TensorFlowNET.Core/Util/variable_utils.cs b/src/TensorFlowNET.Core/Util/variable_utils.cs new file mode 100644 index 00000000..13237f9d --- /dev/null +++ b/src/TensorFlowNET.Core/Util/variable_utils.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; + +namespace Tensorflow.Util +{ + internal static class variable_utils + { + public static Tensor[] convert_variables_to_tensors(object[] values) + { + return values.Select(x => + { + if (resource_variable_ops.is_resource_variable(x)) + { + return ops.convert_to_tensor(x); + } + else if (x is CompositeTensor) + { + throw new NotImplementedException("The composite tensor has not been fully supported."); + } + else if(x is Tensor tensor) + { + return tensor; + } + else + { + throw new TypeError("Currently the output of function to be traced must be `Tensor`."); + } + }).ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index bce64198..7aadb206 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -248,7 +248,7 @@ namespace Tensorflow foreach (var attr in node_def.Attr) { var bytes = attr.Value.ToByteArray(); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: (ulong)bytes.Length, status: status); status.Check(true); } @@ -575,10 +575,12 @@ namespace Tensorflow public static HandleData get_resource_handle_data(Tensor graph_op) { + throw new NotImplementedException(); // This implementation hasn't been checked for some reasons. // If it throws an exception in the future, please check it. - var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); + + //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); + //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); } public static void dismantle_graph(Graph graph) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index d8171e2a..5cf34250 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -35,6 +35,10 @@ namespace Tensorflow.Keras.Engine { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); + //foreach (var variable in TrainableVariables) + //{ + // tape.watch(variable.Handle); + //} var y_pred = Apply(x, training: true); var loss = compiled_loss.Call(y, y_pred); diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index c7b9157b..1ac4a277 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers inputs.Insert(index, value); } - var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); - var op = graph._create_op_from_tf_operation(c_op); + var (c_op, op_desc) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]); + var op = graph._create_op_from_tf_operation(c_op, desc: op_desc); op._control_flow_post_processing(); // Record the gradient because custom-made ops don't go through the diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index d7df6eb2..9d611efe 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -51,9 +51,9 @@ namespace Tensorflow.Keras.Saving.SavedModel _all_functions = new HashSet(objects_and_functions.Item2); } - public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + public IDictionary Functions => _function_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); - public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + public IDictionary CheckpointableObjects => _object_dict.Where(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); /// /// Returns functions to attach to the root object during serialization. @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving.SavedModel { get { - var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + var objects = CheckpointableObjects.Where( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); objects[Constants.KERAS_ATTR] = _keras_trackable; return objects; } diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/fingerprint.pb new file mode 100644 index 0000000000000000000000000000000000000000..361ca3a8a0d88d36df3cecef2aa4d8e0c20adbb9 GIT binary patch literal 54 zcmV-60LlLdhmweevB%Z)vjGr--Qs}x`R~3o82#q7j+2Ab#FPOb%#YgT*_FQXMkwI7 M@!0ai*P1ml0EfFE`v3p{ literal 0 HcmV?d00001 diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb new file mode 100644 index 00000000..b98e1733 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/keras_metadata.pb @@ -0,0 +1,6 @@ + +root"_tf_keras_sequential*{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}}]}, "shared_object_id": 3, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 784]}, "ndim": 2, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 784]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 784]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2}]}}, "training_config": {"loss": "sparse_categorical_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2 +root.layer_with_weights-0"_tf_keras_layer*{"name": "transformer", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Transformer", "config": {"name": "transformer", "trainable": true, "dtype": "float32", "a": 784, "b": 10}, "shared_object_id": 1, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 + root.layer-1"_tf_keras_layer*{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "shared_object_id": 2, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 +9root.keras_api.metrics.0"_tf_keras_metric*{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 6}2 +:root.keras_api.metrics.1"_tf_keras_metric*{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "categorical_accuracy"}, "shared_object_id": 5}2 \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..f22755e07b9ed8455dea988d0d49c1384bd602b5 GIT binary patch literal 47187 zcmeG_X^b1!c{}9pa`^52c2`S1X{8kz6)T~Xqa z42RT(rgf4waE+#Qnxt)11PHchS~yMHq!$VV2!bL-uK@ki0_l;qC{PqB3ZzAvHZ73$ zdvnilh8!MiTQY6f%b9ue=6m1!?(ci=@f`W)cSG>qG4j$OU~X7n?m_e+Lra!DyyT^++|;GfRRrCOnOmjDbE zO6*-6-iX-|jxlPE5k3UN&q=lC>P3bKNk_;v_QohgD%Enno@4LF=U}8(l}bvbtjMBd zLF;^Q&GEoQ>5Y<)VxjEEn@)zED-nuXFK(FfA2}<=f(R0Vk@oC2kgq zveiJv@-PgZFJGv@b1-;O%H6~{8&>B`!2u_@^G=+N0(YK-5z^yainC!bJ^_Q$or0o1 zqoLRUoQ08#MJXp=N3i5-;}p6`VFrdY3ef)0!AvN|!4oikNtW_{^hKM~7o(vvL@y!Q z%2mWOlESctQN-sYjB2zamZWVJ!!t0cf*?wzyr>ilId+#$K}2h(2*WB7EAbfUzX%iy zO6@o!?Gp?k?1M!ZzAWA9@N1;quRV@m!`aW($QVR(s(Cs8B$5V01!LiBlA``bJVn)sspey0qzow4w;88a^GYm?yumYiVlTxaunv*S z<;~i*bcf)_ish|Bjgb)JAVs^6SMh0+=`uVF3xm_+TnN6-!QAB&`-BN_`l|l*UTNebv;1!ez*c-d;oQC)es1IFtk=r-P;AXv)!v@2Ms?EX{ z8CGglWD6Q111O9}wY$h$ z*vFLE1LLALBbxd_$q1p$Kw`vDQ=t^cE;KmVlReJiMTpidya)EG{U@2YQvYX@SuPny zwRZx2odI4e*CgZv+hhcKwH4rMw+wLgM=-$Erh&c=*C3)|L`-)Jdd!LQC^0&Q7h#|N zjX>ZV?*$%!I^5kDfZj11fc`Vs0QAPe22g^rqAXQQOz)#~kh`6S1I~GvAkOY#3l2HH z7idV%K7tKw6QTu_oT?$z_MoCpzM#%QjS#DA5Mkc@PN2b9?*OXW-3-u?TF=4@h<+-ye|!t3c-(2{Mad%bu&-)K?K5QFFrrXu^;cm z9|J<@nSnL#quf9u!K|gDmkz@r5j)zHt8yuaJ{tOZQmtN<#oJY>QbAzTYb!9SIS@p1 zBfhyu_u7c&yv4_onmW(mw=$_9j1ft^jGd@X2>Ib`_J3#j2pPzJZx&t(L5%0fOi{Xv z-@aX_T^Dc5g{|u~QK}SR6d~e8F^2|=C_;kg$)ft3c@9+>F;o)S zBhw(4ilW}pK0Zq3OlnuqQ?Rg#UjY&=S89dr!acbP3Qp>te>@ab=bU?0Di%$3l`_M#EsDH&F&OhJmVDB9^F4%iVj0^VOQR9NW$EN85GRGnT8}m-^$H<6i>T3lE zf1Heq4vPe4_!DGOtXH&%Q%$ajcZ!la2#mY%2{PeQNBAWNg*Yx49G_z3^0sE=a?(aF zr|jf%x)r%38_A_>#vkQ7<@7~ykFasCv(`MC4VC!B9?SZ`*HxWqVr7Nz87NxXTR zMbXUJdsLgf$BfyVcEP1Ke4e0~but!G2a86CX}WA$LrGnoB~wZDVJjSZcB}jG-rZ`U zl@{#Ih;l<4R@Oc&JJQr)c@dq4bFZ(bgAK~FQx$ISP6G{zQ4>x9vl6Rn=w6M_I$Waw zA*_hQ$HxT)4d^w}<}NgC9WtQJLL{_V0EqgBwy1x&j%t2OVn%b` z(0>)85iZKHe^bSSP>pHynH{4&M*AZi;%B9$AIMlA2gGeA9A;O=U^Z;zSrKU~!_@ zOlAyKvlZ|mi?F}py|kq{@r6NjKLfKJ**W;ycRZsBcBgvKNrSOgxVf-*oJTcFgHWMC1T&p7icfqIvj_{~bqpqA(6^4}G|Qz{%SDU} zR^$>℘z>EX*q0~|YHYEBZ0LJIgert-Cv=3gQP4M#Dn4@Tvqofft zO;rz_wqV?Sq(mJ^5t<`K@Igv5V)=b9qzE2JX{l#zh!J)lF-~e(YL1wt4iNKYFT^Z) zAf|~?VMEE%>y47-<|tY26-t%^p~SSHY$#cNy-~8#93?9qpyW5bsARefnhy5Sba zsXT59UN@J#ZZ3P>T=BZON{6npGD&iJYkQQ==GqYmfUHbW)y{$ zWmdea7esQb7|{p-py~dsRe5TZpi;f9!6;ndatiuBPcLIiuWilfuwx~I=~M)*)WE8G z6KvAxBVd`!Eb|1g-Vw{qS&=v62>DKpV5!y3bF4Vz62?;s+j3&jP2S$6A(;y}fA4~z z@WF!5gz`j3CwyiIHsENBxx{pXu)N-6!oC=S33r{Z`-VA#qgotRFPhn<8k(L5WS?e% z!gpNCc1if@)EW9gPlAlO5NOH1w<2ABYZSi0!C41k${(t;B_OQ!My3x=Uh#I4(I+YwcbzSJ2QGV9MS|%IA}J z^*E8A5#b03Bxa4E&}bGgK0J^Z92j`%aax86SHDDC1Y~d;rLO16yB zCGO(rHW1!kwIRhyBuC;xLqE}1{!msDW<=PNi1o}p40@fg-}IcXj9RDGi8^r4v{Jv; zo0_e$FtB@O zr<`!xDKXgNRkxcN5S|PoELzKx%tbVLo`t|tDuB)NyW@bt;d#4m(d%FH;&vGq--&0| zahXy?FI|pt{(1n`vry*5I^ujj5sq>>D z{D!8yjx`;wO^^9eSl24;^IS7-vea8(Nd0EGbES9Jfk&4bvh?4A4=%e=g^H&tJYxHj z3!GkP#kxZ)p1NxLT1(;hfGNliv?<-%l7*RfquP9~4Yhg5Cscr|^K?B3Kf4cJ4L2`H zA`PW4X*k8Im2L?h1uxnv7q-;J^1e{p9)~0PD`xe#XUix#3CH!8j-pS`)*M#T#0^li zn#Y#y^m-Scc1K{r(Vx0zr(rZ!E3G&0R=9dy?@J9B;o@R@O-x#^iSgKAfZq05r|h|U z*-;|k3~eWTozYoqX)_k#UsaYe$>iw-eSfocLH@nja@{P5t}hE3%Z{jvf`(4x%I6xD zAaa%k&X^2Ve*}fq(fdB5caoh{dk3*5Z&#A)CvdHvvD9AwOXs~7e49QK9akZ-t|xU6 z%{y)&TL$$v!M;M=_S-Q}bVAz_pxJlMlTa+=H-i{ZwJr2-L%)X;ScI#s+lpo*xaw*k z!*O&IW-mLp7)WAJzMA)(sz3|GyuLF)KZS0u>qEqSgHB&Im)M*5p$A;AY4BOkr>$)d z%$v^1^97rJWDuq=s~bY}i@?r#l~}c|8uxyyWnq>_!Py&B1omNEXsVF_5`G1DjxjeF z@!^5MwSa*Nz`heU_axExCZW_t@fl+y4~zrt_aIl3?CFBfNURp^DR3G{dDFm(n+8zY zW}7Aoo3>QeW(!|MO&vgt0FpBB`ge>A2dz zMZ9MEG;j==*1;S4h6rN|Jhd50ykYHHry-9u{uKNOwbUKBF3|@xWg3sbbQ?Pf=m>+a7X6&lZ zmmVZInZF z!kxV@efURca2UL`r)}e_oUse7jV{hCI~8l(QJ^HISJ#6m*uyqwCFvuC=|6-ZNIU<< z?xy0km&-8_A36nWL#N#mYCf0iYxG+vFpXTUs|uAgl<}UleZ_CmOdXwg%NP6l7&kuX zHWAmkaZ_{{H%AVz`}k1MBgla7s3bM2KaSX$h0F1XnEBUGIbk3%WRXBmJx;^ zB8vWfw+4{m+Bof><#gZv!JS~mzWvjV~^~fFLV<5G5y1bAoR1y0<(uM9c;0>p&!(FWfvT|Uab%pRveC8e}%YuFr$5iZt%L9 z{t9v9+%fO{JNhfc)nkzwN`(3=#68XdqWu-(c>e4gvO>I_8l#o9U2$dmHAcOwG3r+_ zG~IlvjnGpkRSdh}@U~TB)X%UrJ)EGk^W?lUY#n7%`W1k0BzN{c0NjJ^0dbc!yHn<@ z&Fyaws1R@WqW9Ya{w=^4dyy?or|k2Ii*?5B$@ul&Rf~PQSIJTrfc;!)KUeB~vT8q9 zs-Jbxa+}{bV!g_nJ6GDPa@^Hbrw{6`Mp^A$HOj7)zy7WBQCHg;Jr$nj9sE{T_k#E@k^Lpw;a@yO{g}2alM$Qetv`uj_DQeS>@i zev+fad9Hkze|4C=9U^+&6-oJ?8>qhBcq%>aCM%bea&@yoxvCE+aKnX`UHxhQn0d5b@UZGY$IPqjvL%<~ zY9XhcI8nZKL(bLodHKsQ9$ek_3G~8uLP6J^rdNWtTo(ejTo2k3Ge%$Ia0}wEgulYU z$dmPAaih9*8IRmRnLfx)btWh7hZkY^#-Jh>ao3s~#$gR`gGYKsO;XxoPq!uICK-bW zZl0*Vg(u0O8-V+@Rd3!V>|zf1yu2yZi#3Ie+z;~*s^AIg_+Sd)=IwH+cvpQdd_O!u z@L)RamiY#2NMJ-Nm2kI2c6^0G!mot~Cb-`%i$li65`29S z=QVCzr|{yiagoN0BgTb*7e|eYCE>5a-LEqj(n1e95mvhIB)z)YgHCV(=EFck7BL{D zSj2!A7BS$3MGWxATow<8Xt8;>QPLaZ;=jnjCy~>zGR47ZBRZ$4v_EC>5VDoZ6H^8s zA@^xVt&?#ZH+9QIjD_F^jPr|ln!B<2sR9%HF*0JE@2lW;o5#tx^B`ZC;ZKlBv0lkb zHCa?@az(sTl<+8;S@{mCHatzHga^7`sNlIWg?p;13!w#a5~leCahsI!Sr}6aTP3Mh z$5W2CrAoVsv4~Cch<-{v9OO@vecFj2 zRe7sk#FIDfsq#uJmYGWpQ+$$)X$_2b#`sYZwfdgp#|W!IDU@6v-OG=YS?^v)`ANcy zqIw942nl|QEV>UQk&5a`p!oump62RlRFYF9?rz(-%DRP;{5Pa%K4nPBl<*f^JBto? zIQT5G#b;b>F|V`5?ENrlR6tFYO)uHQ)G^lh;N=jGH*Jj-!9qscG=FiE3No;wUKY9oqRjv$G5Zod^_jh z+gKZX8|wkzexLE}8HqdL#d-w)G9o(*IuPTN#lWG9Rj z-@P&tt%1TP83V6b1OFb^6VTtYrn3w&je}T zwg_MWAD18fKV0^|xS-w-?#X2{&mMRq7Jk5Gk);D}NVYB*C-0$Y`c@nJ`a?DLwul~G z{)dX_KOM@VIPb~ve0p%>q2PyJp#|5ZJWOc$Oep(vJ)jnhoIG4y)GNGTFzBI@E-b2W zEJy%(*X`<3u(E(==FE$GDk@G;YZ z#$fAKRmvAIoC5gU2o}K5jqpwR?(6Idr#Li0PLJHpF7)AH+5R1u{p+^kDwf!O5W))Z-(2=Tdlg4{z=pmH1k?*M z^&rB3=Cc13I7nd4!UN|T!$(5dANC?B@F3&hJdE*gx5v1~_(w=ot1FNy1&H#GlA#(q zzYKpqMuyb0O5!n?v~u=ZkXsMsU&a0&clP&$v%d|izpUOLCMM2eVgfg)({dk8`e*Tm zN40pvV?p8#r(EM}5#?qY9@Y6m_&6u%!*<+KhYt*h4GoBFP$29`=?M@nfzrsxtW!$VxU z_TQLAUSuGz?}S$&xf4Da2xL7Ct$}`WCwxMCi`OcQ#1!!Y7}Qiq9a3p?UX9bKThCqSbo+X>JeV;b|k0Yh;Zhm^V` zU5(5169qhpmIQ|fLqr|HV?$K{cz!xCelS$|!S2Ydx8SUI1WtR{`79S?=dv0@%OgAC zbqMc-2S9i&oP<-Bm<}McJCWDF+6&7lN0)ZODLAo%^+gM>gyzAJ9-b1nLurR6YX;0{ wy|*)P0uUC~&-NxeoMV%n+6kY6#hvisAR<>)$n9yhpKQy=d~atUwZ_{219nu-ga7~l literal 0 HcmV?d00001 diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..399265af6a8977e264075c30974497c7bd37ea2c GIT binary patch literal 34194 zcmWifha;8m8^%volB7shIVg#gB+2{SIwVA-k|aryXjl=Ia_lX8@9a$~d7oR_(vtK| zTiPXQYw34>f5AELGw$cQuj_MrPWK?ZPEp3{lpAPio51meN|3+yjLJW!>}@mZelVWOZ7Ib8DS6-~4KS}He85>g8w*onKt7Ej*^;I3 zE@=~_*T#abP!EmUB84Gz21Q{V6t z#7C+L-q?I2PC@@s-JW2&-Mbv63Jq{@cnjN?wGpHbj4h+eDlPQOgBD0p`oR*tDr`No5=Jlk(V$Re zOyR~b6IU>Jx;gyyaM?CaUaaJQN;(L%~;QYqC`+4dKwndlICpIKY@Xi z%^Q$3l~YlJDBv#WXZ6utEJ9Ts zJsf4GfUlMkeYY?Owj4}>H{IunVfG6uWGe}FcZY~W%Oc?XE@Qi=%)s${S-Afv2SR;{ z;8S-FxCUNkZ$B@B6qoO8r-cR{n4$^RyQ^TiLj?LxA7o>7&FK+9Sa<|c5ZK!ceV2=|S;3vQd~Jk>kE?(aSaIg- zV?K4;p9|anO5m|N9rSlpgJph&@IBfOnl{;fqvg(-(nQ#ig|ImO7?Sic@QFb0JE~?L2Zpb%nLU`<*h#Gzb*<#3&qGiy~*gx)%nloqffDr-<^+j6LtTnID6%Yef>eLCl7HRYE(b+kHU zV#&;Jw7NWs)-0Kb-DlG2Z#D@A=UG!__uFK(N+OK=uK>}+C0P2H4|dEOYI=PE9L|o! z&G$-S$GZg}tXYmITn5{gm&4eEFN8T-4w`+{7&xy6#$KMFoLR@1sAt8%k&kXaJYfTJ zHrmh)zZ*d+(FPu?!~x$!g6PiSlQN|t+V!Z3Ju8_De`Ldm;lFV*oDoeqHJwJ0{tB4Y zVFq85r-4smBaDo%1GUm^&^{v#Z+@Hya&sI(^He|i6|w?Ws^`M_0zTTFOvIZJ*XZ1Y zWQ1Qeuvfo`EbKOimmgQ+=(`?9&U7PW2x@`rS}XS5y-XB;a)G?~?;Jf4Y=V`#UDGmW5kO(R0p2H1Z5JN+To0P!6* zII<}e4Ra%jXS53yqk7<~`jWYEG6O6fO;O=kEObpz>*zm`fI@xx;QeF+WGlp?*Q`AF z{U;aJZYY8-*BtswDFI(?TMrsrl%eUTF+La1hnd3qXgcH%x~Ei$s(lt{xdh{9%bn2c zw3AqhNYXBYa^{RtIE>^8^uQ8D?9VTy5(gGTG(QZ-5ADSnBDpBik%52yj!^M6b@(cF zCvu*hWc^QX0I|MsoK>6w#xWYWMy`^KRte5Zml!=Mod}I0)3NRFW7^t&1i}x#p;Fb>3cl)&K6Ho}>nz_13{u%ssjQ?wp138^WtvsF-AdJ{1vT#s@SKC%Yu^r2T$ znodbfL(N(}+?iH@Mi~oWCyj>LA|7x^=O{`<-=)>6p-xY&z_e>8 za4tV$&+j<|Jy(kHVNW$!J^#nO%R*U6zEP1@S93MP40H4|lD3A6*&cK8A=oRI7YUOfV zd?6i6ZDZm0DL#b1Bp@0mhEbj+z`df(d?`H!hZCjHG2|?LH=_dInY^Z+?V7~)P9@IC z$fGxH^ijIt8S#^iK@Yh@Fx@~HIeDW@U+s1jPdrBLf>TM<;vgK=XnPut8hVmF0kv?cJPr`?WIG6PiA^W4a zd2uw{Hnzdhv1V%QZw*nCHBk0+6@8~_gO|S4gW04ET-=xiC43nwfA0#}qgV)?<*Pw? ze>#yao_Az}HnyH6fp#P$(hfGtDaBN3%_@U?6h#7@`4HnUvx{TWL_ldFOEX(@3& z;E!K=rEw&S15p-9*cEV{v0FKX{+WJ)c04eHsDn)~67K~LpEqNk@NOsx=CfjH3m`{N z9=Qc~iC1_LcK?f|!s;rRzg!Z>m`U`Rusu%RF&$e*(vW*kgdW~n1TnYMz0C88*NegP(J0WYBRd!>h#hg4HOSa zrmvp-1Pmdu0Q<>}is`Qx-N$HsS7(y%^tP3tH2H@uvTBi1Kp5?fZnW+oP8Fh&F&* zn+>k~#ltaHiWu%JBHP>Q!MCb}zIi-@WF9Mk=UY>t^0+^4>50TnZ&45uJ_1I|sk+z{+TFB_CMY0KdS7Lp8A=PnmL(CmS9Jk)Rk+7l~Hu;?QD`3kcMMF8fn+1pX|=^ddLz{hhH&bVChTf>f8!ge|Rbk zgj!MI?l92x51_|{>VZtVPkpc3ke3NMIGXyD@f8)L5$_V=fzm>7`&5Cg{kEX2okD|S z%V}?x9u*rg1=qceO!cxqq{<>4Y#j;c8a1(VVj5{HcPaL!oh51F8(?(zNviWO33@uS z(6Pjb8l9g3kMue4vNa6s+7(IBk~l1qRs-(IIZQ$aA4avM$h+nK5K?rKrl>tc+^DC7q{`Jx``e^e6jFM((%camIMlY>9~#8FA1 z7NP@_Fya0(Ozt#+Xs=W};gkg0i&LQ1YZ@M1=MHU|ugKk@aBwJS!rAxN!9{ySEeGZ`e=r=RyU&ithl?d}V6M{H9F>&Y~ZP1#8M`O!T#Jz$l zeqmtH@*qlg>fvOk>1b#wOW%!$fOB>}@raWGqm(%?v~3sKyEag%z%@8Mq#2tnC5ipJ zMxy=C9zEA60qtDHzqN+TIxD`~g-E^07640-R1nXPIYaKQEi z9jNid2fcyNCF;P`)~|wbktJZ+SwV~jL!o)(50$Njl-dCwHz&bB zV-^t~&>+4+Lu9XU6y*lQGcP&{Q2t&Pajxv6ekD=J%2{Gj?;C1olE7Ga*JF{e2q-Eq z18Mmz+7^^RCZ88WlduSEoGS*h{W<8d;yZmaQw4b^ACt;JZ8Rt|hJs5MXl;lWc>4b$ zB7t&{IpBtTZDhvf*T94*KAd=z1noU*F)<>E?2cZ7+FO$0(d}kTEO|@qGQTsYF3mwt zzdV|{E&=F|%XFm72}Z^QYuS7M zygSVIRZ>_W{+#GfuA`$J*0>{KADm1n1M?sel*p|^zDqC5yG+Ql>;{zkWe;^d4cHdh zgtJt9WWY*XQ3e;ov0Y$Ln5zA&cTh}XH7X>Z*phO1Y> zq@2G;sD_pN zcAs8FG+i4>k+d(A#k;_{`@7L5R1vwBhRhr7tF+_OTCm#Si_IIvNn!{GdadW7P|-#Z z@5rW;R)m4C^c%)5*oNB8nMfy3sfFJ%N6_AP9~xga0S>pdec?wN;CvT&DSKuJ(cm9!)z%d_Za53uBnRk7jvJgx z-vim#(m`)RC|sClg`Oq(g!g7Dt2ikIf2~~$64Sil@E#eQ?(YXLJS8F7el84EjL>#z zCAz$^34AXWu>+lFschK^Vm~X4s)R4a&KZX=Canc|PfxH=tAWenVtN_{!354Jrh@3CO2{@TMvhn@yDZ5St%Wl{&B_PG zahMG9E9j8oGT5>z3`Xm0sM%;L{re&tB@R`h)xP%e)W&SQ+3JWkfe~0gn1|!DHbG7zhpsZHfVizC zaPCqfR%FCNiPbBX6bpm@^$hGE^+JB})DCkKZJ4K74enBHbZ53Y8hlv=7l$kHP>=!m zcWejO+vkbqRhDwjekBTL45xwYa6Vb29tQnxaxkGZ z1Dq`4P`cv|YdJg_n!gXQ=KG2T z(F`(KssYxe)C$?$C~tHCS-A9&e>L zfm|vFrOl(6_hy-}aKZ+VUlm9s`8hN)zK~j_TY<1}I@tBCz+U04#C^(RqLCy9Ne)ri zeBm7Xy>Ken-;w~=8fD7wTEZ@C%Yz-tW7Hrp8e=WHh+$?i(RZt&<6{Ed)xl$pjvPde zQOlXjZEHbkc>_q^$tU`b1H`x9n9AlqCNtfw@biN$aC<7?k;H|#xA!W!eAWWg>;pk4 z#Ri=V`J_3moMhx&B8p2Fg2l~Jd={_<2kbS8ut^rY`zZvR)q_S7rH7#`zyjpVG$AWZ z1OI)_1?Q~aBt9k@OLB6VqjUG5s6qk8ZEnGYm5reJL4pPpT_(+r$I$gcKl8599`;-m zc%P!FtWl93v>pFVW_V1-#Perp@A6zKd*TJv_?dwPRnd56btLp|{LPBx*n@c|7knF( zh{O0QSpJ5fsB;mFwn~wdVwUiW>{x|DGvp762zuUy^ka1ea_W_qk7?W?hefLe_1X?X z{xp!?7v|#N+=;|sVJ#|iTFBsAAx1PR!LnY?a`OXWSfR(qHNn^(zmP;ABjl9J)h}V`{y2g6BUlpwYDIw z5{7pyiVAb50LmSkPym&pd`!A00!_p}itx3RCZ`__1fgu5F@%Q9vtli-c zihs;u$<93FH%;sq6-VZfm>-Cc7UFR(2d(THQFFNuee!Y-*8OgR2{Q4>UAm9{v&jZ) zwi)QO1j4&rg9$&DqMSgtPJPaYSBo2Q!{j>b%T)qj>p3JuEgNfktDvt*3vUXmpqI`T ztmfp>b!Me-vo!{X?gXIcAyeA?@inPFyO7GeWTN-Ob-=AxVuns{z~^eE;Ii{LYFnXqKG5g*v9Tpwd!0x+JG_olge`dtu)}(N(@y&;|Cdsf> z%o(lD9P>9n5~W`I!PkdKwJ#Kc@B1NkubBdtT-#1L24Ty)goPRV3D0P9p%_*y zONSD@3~Zet0$&Hb@$Do&@a;|+b&a-lY&jDHzIRU(k{FFK$p^63oPm&E|IuqUdElz7 zhy$EKnswF@ocZ&RuaU$YHQWNt=GTbSBS)lOm&mTAlhND%5FSs;0<(y0LPTFv*wKi} ziG`RfYXqLZMCq8#Kf*cnfVudhht6p|OZ){s{%zlM>`71tA-f;6d5$xge#r$Md7q+= zU#EgoqYX&6Ae$8;gc^-fFjGtm+ZJWe&N{)Kp4TBdeNAA;JIQ#XA{^ad04uG_;eA6c z)VrwTT+?)DE;kp{TSIh|6@g*jWjL6oO8pmdV0f)0uGh`NqYKlJx2Doa`mYzZ`(30Z zZC9wY+eR|}RtlsgYRNmFw={XL9k^=mWJ1=Z;?K@RDzA0d~qPThpfNv1pTW8{&(XjHo@Ns382{q7;O2KZh4uEHu?GBRJ{-S&t9N>fuATjJCbyl{3B+6ACY?D z+vNIx<1{5`D$zgLK#UDFKzvp<$x_{fvX20ZK5d2d_Kwi?A)7d+2vfyn@wicFI< zcT8uMVAfNC-+t_`(K|I6bZ7;T3{9Z1Mc;^yND4$2S5l5%NQmi-fmA; zn;YY+3Q|2qFy*?3aO3^JwS$-kv`OHGs*MuRVr|3K(J~U@U zvHS_igxwlC2s+Os;9tJLZ}ZwryljeK`rV&& zAYeW<(6fSw#&F2Z4@bV@b@F^MAEpFK;Df8HVN*>b#4GHFPZ?EMS#1oLe9GZcdMqK0n1=V$5Gf&7X>$W&%)q?0zK_i zLtXY9hQxLy=*I2DuQwY8jNcL`i30p0@ZJ+UiFO*+*V` zHRI)rNw{@>0`gWwh4s>r387)caodQf$(s!5o#mF@x@C9lN*eInH>d%vy<_MK*KUasj%{Q0}e?p zgSD{<@aEcY^0{Fv=o$;tfpc~A%*;|Ki26)C<#eg3b0+RflEnDbWiaQ?IC*U%j&1vH z67x0bc=MkoTy~5`Zh#H?O)7!+F-7oF;D0!q9j8CV1@*zY0)79=fSCR);EjDFoR?!p zy;8SXsSis~WmYO4)%Zwl2g-4@M*`)foI!I`;I(d9Opo19!klhZqM4&jcoGUMuiJs- z1#1H5r3GzyBshy(S{PTbcVvd_s4L@L?cibI0i;SjaZf_0;P+E(Cus)SeDKr z-J8p3U3(aGmMnxFHv3`kuUPC-vS3%6@4+8WWsoDK&Kwud0lw~gX33LWs1c0?!&QT1 zaQ97`Vv|YMe{_OoGgUmekq=`uh{_LV(M3B-pd)%2B;9C%p9wr1O&54*Qo;-;cq(Fl4_*u2qXR`99Y!zcfB(X9zjHqPoXUqg3Ta?AtBfA0HNxlaiSYKmfVpscK~I^P z!IP1lXcfN}zIygk-yJ_lh(j&$O!c7Jw*_~>C7W!ToPx@p#c1fCMG`O7Q!7tflz6Zf zhvW0HS(YVE-9`}UbcxC|RN=?(;`sfuGr0E4u&!GM=&LR1SQ2E9u3v>1j!rb2e3pT) zCH2HMvm5;pmM;(3>SozW4aU z615b_(9NPLk9*nm6N5l2F$d@T)4?k(O*oLAOdcH1!l&=H;Dn6=?kB^8NQ*yYpH8Sm z=?C|isefCLlm6PM@9-wzs_+=jR&j>7(f73d z^iA5Ta0D}q>nPXzI@8Pw`f|yvcFldeXy+<*a6k8oJY-DKW;7OB-_C$$@|)0$gCxu~ z4~nMGg{|`vP~RbuDIx`>>OWcE-Q1W*? zF_`0tr!3~eN8RJFYjP>*XYWVbe{bnN%P7#1sl}~(Td=iP34=sdVBOz(@R^hki&hsy zL7EfBVHsSPFGTyB=ct5~CYY|?1^c}?$kTUYOGb_`t~NvL#d(j3`=1K>pN=(bYiWkX zGJ(jK13I?anG7}sF`V7X%n|QIJTEDa@8XSdKVyU5-#^m$27egz^r6j)qdr<8;%=P9D1!~&=tJ_>Kbf6%t3Yb5kS9dc!QNuA*# zTr?y5RV78EGf8N$8AHJkTik{mn*j*p{%UTp}l5nW1k}H}l5!B5C)WNULv6BA50R!IK-aps~~qUY#|AFxhzE z?Y1NQ7)yq8AcXji{3Uj36|COQ9%?tEjMfg=fW@C8c;`Pvzj<%N_cOD=eBnHZeb+#F zKWEV=bG}g9_bWk!BZID{`^n7Dd^~Z`5;G2`lUU1E`oQ)74HJrp>`0BBULb3`LJN)CZn37 z7cyh3aCNZ_zT#+sUC>!Jx z_#kB-{G8YVi%py{AtV<3#4_;88vmX6GMvvyvKjQtQ-?kQ2t>&Ze zf~Ca$??2MqC0NU_W;E=|rz#0D7_lN7Mio-XRZ%B&Hi@CqKfTa!-9#7=N+(6Ui7={C zOJDaSK}V4$axWjki=ACW|9SzA--yFA>5cG}4v-rGg|PooEOac3!NJa*WVM_G|bS`?gG)r4pFR>3WS7kwbV5f7-Xfb^IIc(Lvvnw%8)9m3bn zxO+#S#4LLdQ`Q15e+RoExdM7G#uF(?TZnY7rLN%;OjJM}_MR4oz*BM9TkwPpY5PE3 zPhDfCpY%h$B41eeTMfq+@20>0%!7;N;Z&rriwH$70~c=}=&_IjPR84g_G7Q9oo78c z8uEp9y?@qm@M8j$?TE!kO_eB=XAR9OSJU>QW|BOx6k3$41-@z#N*BhkUiWe!P$3mu z_h~Sdxjx`OGB#PY4lAYC3cP72A{D;@u6ziGy-flQ)3*ie?VeDN(jFT8G@Wo- z!x>fXalyUJhMq;&X@6BA4!X)SZAXe}{_9K_73IJMi4-_{wFu8QuYp}r1pk^%MY+&N zbmM9tOtEB{Zy5#P<5`1uLmaS1sS=`{CkV3RAn@z01;fwg^m2$P==@Cw{?<^^_j?8G z39Lu&9}j8Z-AvH)-Uw^@KTy6y0`YoL3{OxJdu=z9Yf+J~VtEdTsotRbHbf#n^tuse z{T$Ynth8VS2@7t$eqwN!wemt)nQ#(3Y0wdWN!_mfXK=? z(lN3ge(lS`Jxhwit}*)fUGR>zE5d=Yg>cH=4^COSf$R!#u=p1b zaejHA+m%Vm{R@akmjNU#S%sfJ=t5rEE~pn>*1>Jo6reDe2(fPV-oh4Ht4=)m}I(lvDwWBObfc;cBx9Io?n zkA69*Z+uJd7DV9Tm*zOFassy3?jaA1IRdt_gCy`}@%8j<(78|s3m*vf=7%JsRk#P{ zr}H6)RYTt9xvb~WeDu7nP9+2H60s{GsIIUM`A^a@b!Zc22sPl&wkEjiQ4Ds?TBLo+ zIr_dLAB+>EpuAHZ;d>SM!ErpiO%yl7X1wXV4SvQR2iL{`mam=JacyQEOk7q4wXI9x zzTgg&qF2yr&EaZ0lvayxWccb69OIrl{UrnWbI3!gi#>sYS~*_>OQv{`iF(IS=&2rN%sxw9JJ&Iu2aQRJ zo+oWNSckm64`j>Agio-`BWdn=H!9+dTVli z-XY{QS+M50X*hnV3QZ?Pp}BxF>sH-N{3oU3)U zOCWcx81iagkUD#7I9WLADs6D>U+2j6SHQvau+ zFlDcR4H|G_JaT8je~WXl>&6;dvRQ^ruAG9e)U-iLVIA~6mPJ>w)6{V0dSddi4s^@D zvnhLlhW2j68M&dTIh=<=|LH(aObERC1%lkZ-Dvc8J>&5u7v(mjB9-r^B1yMtuYiX- zC|DPl@ui@=^(xWslg5khSK_Iob79}fIIQu>h0^Jju)A3jTRfGKGZD|2MOKhoX{)jF zbt(FGZlTpx>U7jLhdPb^pl8oEf$?4oIBM{Wa*Jdc-ku=Ry3h*UyM-a=0)y)My6DqW zg3U!6$-tQ&Jwbj6~0vOFCWpU5~K-iNy`#zI80DmE9+q3t)^sBZCk9NSb+jvvl} ze+Osb{K~^9KKT(jWnPJf6PA(*i`=kz$0}mJUmkd=iTEg^7+q7Mn4TzCcq`lp>o(-U zsN{68qx)!V)>kV2Oqa}Cor>Sx1RTyZHTdps2`~OE$6%cfa_LMl%HF&{?Y*87j;AN{ zrg)45I27WnZ>so0`w{Wr&j-GPJ=|K! zHx2CbO9Y(DJ<@kJ3LP^T`h*1FhT2>p#TNxV;X_g|QiIdWd?50PC0=Z=pc&b>X~`o6 zIx9y3<1@VR^nqhwJmv=_Q}2+XVLiC4GDbtS_oF;5#+&Zj;AwmVNYC*g+%1W0oJ$34 zdXo!>3OM-GX)=F%urfFq45O}w^eOiJ1yaHzFi+u2T)F5EFi&64MD&`i|!jm91C~H3q zizg*O!L%ys-L3~;Te7e#^^uWQyc~@BJtZRf(bVE}G`y+SpjJsL@JBv^^!_J7$7gcD zVNV}z5~{&(-9{MgqltD0?vm(-Z|GF%DT4Rs25VdTnf`tghM`@lpm}^Ma4*y{V|%7S zTWke-Z$5@L!)f5?J&lHhU8QF&^MRK$XvDj6zN3Dg9H#JX(WJf!Lu@}1b(LiJH8~Ux zG|0f4qw;jIcQ#n{KB1)oPkeBoiq$OoOZM>zu8B)Ul?&y-{@#FVTx6jz*$xNiIJ2?= z)6wyhBB44J_&1>jKGXdWP!NuTwTi6nQ5{+nu^0<7C{a@r#1b`V40t@mI% ziwjY!?FFs|ac~ID2vM%W6_Axd?Y} z5@f%B>f7bdreaQIDt&yd6!<)Kay3E?J%X0u+#BV1bg>o=XyzlQ9Z2(x0cM(;I1col zBll<#svob$#e-4MwVHMeEU3Ysa;4yNyah)7a|3n99hfd#4Y|45SS~q4orKiE@Uc2> zENc{GT;cdyz#WyquV=K1)hgt|2oL_)u0oPT#1Xr0&c5X#10y#9pbExY&dt|84`0$)r<9{Uzk@ zmKc!UwV$f2EQFhzd+GoAgeHaAD5i0jmYldmrQB+;<8l$!md?hW3P-43y^`>P-y3nq zJQ%yXlF;w!g)2p;;c)3pIHwQ>4Z1Nfx>JSrN*txhp5C~ZJx@D6O5om#|)+xB;l9q&Cq^Po@nuHP~yja zOnE9!^^Xq`-@;ZlGQ5y#dmYEqOSJJ^V<~cWEGLaOYtjCBIvvWG0^*g?*jFy_v+HUp z|HvMrMuSbD_)8DQHg}O8+Yrdyk`B`8^YQ3jU5IR`B;3)BiVHC0T?n?~^dM zr2!mLcEFa0J8(cRlk`sPA{r$+Fqr&=lw_=9bYm`1AM1MP*G?kcU%hbLXBB8FW+Cq; zm$p}>(9Z9+zaDt&JEllmbROO7PT6b>KADEGzlbONCA= zp`qb1{2b&BjcDG;B!1Vv6y2Ge0foLL$T_0Kq%BXy8G zwF}4f(_)d!bz+;JDNy^B|IsC~br3!GDBS;=1(I^sfW& z(PZS9`<@xGvxKy{_OScEDC|(24-(6^;OF#>P;8-!c4s>2-tuDl_(~S|d`$%N^h7Kv zPh}jgu7xJET1eHB!CkfUv3JfHQafD{-t6rtAFxPl=j|-S9n8KRIcpBi-h|M1-&>41h@S{y0vtk?} z*g_K9mmi?}n+tH@Ul?v0R>Sft!F=DPLX1p#NB`+m;gu(4cn4r(2K0O7DXLSxWi=^UEtqJ?$}r>ja=UcWZxbY6jHrL{q9EN=mRHgnguBOB@%RX zhZyegB`Sa74}D-g9eXW)(WSyx7_F0xsZHS!yg3-HR43slIZ;d!Xe0aP3fk^+2nX%2 z5RT>aA&IZ=$*+l#J!(r=N18{ z*Tr!4-#E~HE=)B#JYmZL!FR1#M{@G|Y3(&F(ENOxrtGyPwXOOvcv77Wb=-=Hhh#zB zx(pKSE6JaLGThXrgJZ`Ik~MM>_^3+(nw_68t|_Zn-`SVhxq|tGQ_$&bD7ti+fwf@vOtbGm^7q_#$wCBIw~4M7JnoaXIo(9#BAuDHAYjggqjW=0p+;UlrK5P*8h#b z@2kw=twt5T(kMZWeqqOK#b|U5f6a(2;FGS7SbA`v9?$o=!AK_`ho-v2;KWj9yMR5~ zzefc39*CtWZ(7K(=m1FwoP@kc1(-Fp3?7d+VxUn5lpJQ?eQp2I9)+0=>F}#;6WS-1(O$#%#L~DPwIzdbkRr>Q_qb!=iziN1PJx`o zIb_S7Xngvr1Qg?~!E!|_o!Od#;|bQVIouB?E{ud_{(@P}rT<8QUlUpe?xSY!98vzp zQJSqJiOYf_pk&ut=0#-zjcSw=%vM(zE7d{qbqrSOt z^-UxC3%E}CX@}4$aI0XZu8L@%EeD)k6!^*nuQHI#?J z$$|8rfT=AkiHGDKEvP^7gTy?^hyD8TU?ph}VLNl6qt^w7`sV}u*NnIJHsP5)w%A2q zc8m|~2lI{Y&>Ip>bdmU%`ogrGH7KnwM>gfwfYUrL(3MIe za}u(!y=@=fSmuZ$jZ-0UJV4OHrlZi%P8b+CN22F@!=TUwW?OIx%$5m&nnO`A$mkKq zy&Tfk*}#Kt1{__LNQ`(sxI{gr@!h@xt|uDXlunVEYzw4ImkMevmvE^W?W^;F+atm_ zJ?R~FQl!+jx z)O~1Z!WP0~&3fqYR>N}VXcWU<60A*|1<$5 zHG3#WD08`%%uL+7KLX>UCqcNEHNI6Y!c}c0@XRFzJZlQ*n|e(Wl39;Ze-6N2)yG6= zr8s2%EW+2CD)>3;Ih9r{Bl>3(1=#_Y{?sgi2@{gBpdgaY6>unNfdU@bg9~$?0Dg=H zloZSh$uH-l%d?+!l##&x`Y-hV^Tg9{>Z8n?V!$ne+xYb(i~1$-b@onF)Cxz*r;G8^ z?IvgqUjRLM&Tu<21JAz|LD%+I%;3csCbFxP3|@9*hgF`E(!>bVRhFU(Qxd^(uP?oL z(-RLoHHJlI8SpZ67alNN0L=>iga_pv-{&00)GjlWk6l0(Y!}RqcHg8s#znBWYXb6x z{}RJ`S86yEOUL5BQ0wyv&?_@e8cgHy%f?jjXivq#i(J}$u!ER6io(y&DY$K4JY27? z!Sl*X!ERzE)3eYLB)_W)GX6;r<#7xoszgvcC5%>GD?z*dM&y|OT0S(*3uY&4L#|2^ zWFFdq@AU*;({D>Oe;o|6GvftpkSr;V6wGPeC_}#k)CC8L# zqT+En+t&#mtlkT}KO4~KM<9+Vj?g7m>NxeVIl2TzVUFWGVt7}aO7@?nU2R<*>38kX zzVip^Y8PQ&B<9hE$MktPpIY3wb0MmOdib!8-C3#zPfm@cu>^oQm6Rw9NciTA2 zU;m7~ESMkeA5+Cs>t+k)n}X?M$0eu`RtCIH75MFz3HtaeVv)iJYFcOmU786j*JdSI zviCS7dPu^M`AYaxy$!8T*Fn=*I`n#^F%BB5fcyFy;bcE2mb&xlS8fjdSXhm_H9Wvb zakjwY+1JrslS4&cC&FPLRdnatASY#*@RK=Yu1pnh&yO(NPA9kdsYC5f+IazP72p{=Nvt80>F*iy!y4)5T28nr}vkP0k&Z4DMus&K!Z@zh)x zFXC7|ElsV3w{x?=%J(?wlBvOn6^U@gvR;aiknB$9m6$Hg(;*;f6q;Po`2C(;1?ak6eC`uncMPHycAGWZy zqK!Hj_LJkoccAWFML48T3paP?(ow#l;M71M?!E#JzBwF{KNWz3@?xAPzWaAv`h}CL zF)qUqX@(i@YTJXN^||=bw-RhKIbu6`H7$G5L#$qIM7tL$C>1M7`Sp6_?W$~$+cuhZ zEp|b&^9XGei>9U{P70c{?r{8~oz52o<8G}5 z)9SN=oKJh8xtFb`H&~Y=Uis&06Bvcsukpq%Vgz&_fdn?Id7t)<#&L` zts-o;cqeGTc>rgGM1qt|54Sfq4ex4J!WY|6h!sl5Z*)5L-g`r}Cl}I%bBmyJn>k5& z&2R%+LZo1GC@N3fOVvNf1K(2B_=mp;1nXa;h2laO{n3X`;8j9Ez#8~sQ~*iOmca?O zmy}(ujd|mE;M1@KtzH+hJp`|kVRWM#~xFQVsRt4f;;vv?4H7t^f$Gh3t@a1PD zjF=sQA&ZXD_T9_Do^eWdyc5Mf&n`}gX-aHdoPk+mD^Sy3AD;>sE^omu%wOw@sXELr zk(EZ)B(Fi!sA~d~Mt5ku*-ioqlTqu39<)9Tfxc@|0Q99q}DRD04R?4r?MLFqz18E+Qdo zc9Q~mbI8cfq5Py>#@2Z)T(eezag+FF0MSBdv$r6MB4*s27tJ)Qy%Z7`OB24mw6WXN z6f~4P06(5);q=)eIQwWB<3rBoT3XMNlHJ>(yL&in(yT^X%?A`LvVcEKjq>_m9FsbF zj1;N+HkX}ErsKi2gc zVv03erwcb=TemuO_DjbF=S_h3eQ{I7`7%`Lw}JM_Tfs9?DnRtBcTFBIWNjukOa7|stplDtraMCvi)2coA;C2wS^Np}G z{w}xOF$ep`J*RdHdP!ev5SI3Yg2|`}@Vt3APEnhN3-9iQjL++-brzEZ5jjBlr%IX< z-sC{{5EF1{{2~bP&1D#z7_3_To`mze$oL*v{GL#N{8@#Zon;}tD;^7ZhM}Z6cO3oE z?hIAt@5t+?YS29UK$EEZT+p7`LwyZ|(6coaM+kl+OESVhmF zEJgBmui##1chZnCSq!I=O}+giakH2hyu2dG_yt)cH9s0&7^gz^o>JO+xtDf7SA>#1 zZqPb-1&n9D>w{aA!7gGd>Y3z2-*aVDJ)cNBRBw=X2gcyVDWMJ^g>HV9b6h)iLt$F!6>bUj5N){ zyalnO^kPTN!rByw)oVKP%PS3`~K~bEk zs4})$ade(^Io2W{B&NxONM${aBFWIFmPqAmQmHDQC#uRt^vsN5u+mQlwypF=x#sWW z>A)QP6rl!MU4HO2U?)sA=HrnS@~~3R3xu1D!RYp7nitQ;pI9xuaUmQHR+_@kL($;f z@RJM6-;bGR7vSvesc>Yp1Z1C>gI#J2KhkfA@Pso40;(L*FxOao+K2KJ9VPU;5L z2ozV%;#!n~QLaf8?wqiMuS?$17PjB*nc6^Ap53PEEjvL}Wd<0PxD)>M#R7-97w8+E zS_pOhNb}?+smH>v)Kw&ks*cGeE~Xsh%NOD9+TKnYH)C`J3L}oKB_MX zorfl((TXjkxr>Dr`Ge`i0Ta}aNP(FX3t{O=5mXDy1?7a%82G)2nzx4_&#`XinQs+% zNP&;a1M#>w(h1V9m&3K9)p+jgQm`J+qxIi|kw1J3!$B>AyhXvp^l)BNO;s@_j(AHY z3+nLYpghuIA&PN!vk)Rv&`!7j{Y-|#x5xF^lv{yI%&gFI^fK~wY$%!*PsC`43Dh;X zlcqjMv>jmkz?4Yb4uG0n+hK!i zI%Mt62lKhV(b7+Rv|Xb{*S;*oD;<@fZ~BtxT-ktj{AQPYx@ z4A;FdoHPW#AUr+irrXo3v0rx_c`(}$Lc|;BkAU^4JF$y;TfU(~^5$aqn_@h9RTo<_ zMiDzPN=Htq1*Jiv`2J8O$X|(}s^@mmPutbt6JH1T9xtitOn*GQeiF(Zvmv}u{l*3> zr@)eljDOxU8XGgN5dPtf0{ZYQRs8sd)WsDFcy?2pX4#eDOWS;;);uVSsl~op7u+aY zO^Y_2A(zjepmt6TG@;!O^>^^`%!){+*WryPKo*0qE`U$w4scOp3MRNDLE_A-f{=&j z$;+j;=%>B6Y5cZwvdS_JZkc3apEWG&G;SNOPVFW6=Ugg}b?}xS*(nPpI1DbE=5Qnfzfwuiwj;|*qF!N6YPMh&P z9QI(3%M@~9Hsg!PDv&)k5kR(`qSLk~!rEhu-*RFkHcFP^0>;ZO3)UsNsgGzmmO(?- z9#Rmp72Ai*h7cEqv((E#PgQ*wa)-sr-GVqp#S3K8hEcGra~Fo~O9Vc{pbAe~fzPdT z@XkIKvwJ=cH?7@(Up7SHL$PYqdCq4VL#=f2;>pnbK9-uE8_!kP`hl4X4?hfKlQPTc zM7}f)UAnyirsRR5yCv~a41vb`?}_wjMG)D}xVImd;RAsl9N5ERRiF9=L(3$wq2nf{ zyZz9?uz}N7ksyl-O5pO*H#A@z!=S`02mXuk#9uug8^mT~a~P$)3~%VaY84tpy1?Wo zvCuem4Aoz4iq{fKKwkSMDIdK8Z|?EI;J4Cvllf`HTbAQNi)nE6!U1$DUjy>Nd$8|% z7uT>o6PuFvVV%NnoZ`YPlGLIOBMy#$!Na1c_pnP`bE`0YqJNUsTv!8XK@(uyq&Z+) zyd0z3TuJtv9wK=#7v$2m64S~C%DXhyc-?n(T)Ax|tnbo>DV2rrc2_l?OWO&?rxUR! zU=nS19R`73=je&XXn5Z#j-KN@pnLHIw3Tfmwc(*mZ>y8)cc&2BxEUmVUp6rfi{>08 zCkm!|&c>!%SLmymOCL=y1OHF4uxJGz+l=^vd0(P||D=Zajtzt6v2VEMEz)Fc&~A8c zRE_!rm&wz>p-emK64$YKEs?OD1G(qcLtf`9!5G5?z!ik8A?#uMaA39v+XD19ScaxO3YX z2QxlMNO(JapCyAY*D*fNt-GY_yBF~9$bs}KSMbohKomMM@$S?xv@BPrqc=_foo^{P zcLjS-Dn3lx)J_rUw3{?hIg#tPHo&R3qCi@9Cwy1ShmxN4%>TZY^HvQ<-_m?o5MF}b zb)j6aD35XfXTbXAV30p>kk(x)quXk8piTac;6v~Tm~>+!=B>FXupD5xrd4BMAvYF{ zjM?5^nkPu8I{+CX+i2S9d-P36F0>`=rQxHK(W9`E>{(xew`LZ=dG{C?H}f+!O^}3% zO3NTG9LbpUH0ZhgjkxX1$2ISE;=MV-F!EbI+sDoq$SG~dzMdn5ujWK2K39S_`twlN zIZ#a*9df)=bc|53PShj<@QgHG@D2Ohc4VI0IsD`y2 z_M7h_KkUkJ@Md`&_96{7_pb)&(|&Mvzc9`*szFobRf2QA4xq2L6@v#oq5Kt_jr#_r zQEzKUOb)DsA#D8!d-rQUTuu|uNzj!qUenjFIjotnm1#vW@?fDE23WQi8{;V@Z#h=wAWNZ%T!gOoW#e#$S~%=ABD3HRO9~H zyFq4rF%fQc1~-#3uziz4Hy9C=dGM8psK?^C6^@W`-3zbjGwqUXZ-_d z+#*(sgVhGZlB*RE8>S6)g+b(Yr6G2WAD}1qhJ#G}NorZQgm}-@Ar-Mh;C&5`y<5Gb z5j!%`bkY@0#aFAK*5)9WGnzA26mEkv`kbs!S(iLALH1U^1SOi#d-D= zd$A8QUrrYV;osb0U3VEAZO(vD1$VHFQ=^-e%|YtRb-}u2bMV3BEX=sO2w!fwM*_b} z(Dw%A_-P7@_vmb(-Vq+eZc-#IdRdRx9MV8&mK+8%-4hOP6Pgwx=9$5v<)kHU63K<8Pdf~p77 z;Ca(vPQ*UET+j5s;=`#_+iA|~_!APUy%N71D+0MN13G0Id&VVmv{(voljPumUp9Q( zMsVsVLEe!e(3&+Cv`4ja9e0z6>EvVq|K@L8^C1b&u`C9T_(|f^`MEHa#cTVf`;eo{ z2H}O(>CmX+O~W&mF^<+L8vTZ&Rqt}qc;ahncrOoTNUF+Sd-7ISQTpijDF&G23HKAfDIhE~_MVTSuq zaxiTcv)m1e$%nZkM=fwj`Cz=f`2|%`QN*^UMgh;DbLI!hk>GSM z93w|eg`w5naK~moG=JUO6jE45_0AjO`p+{!TQ!Sp-KBt!Tm!KxDVwBxY$J(6wVb?B zC~XTG3{LB1;Nlr`=o=G@%^s>mfAto0o1F_^&u;?xj|}HpTY?h8)=;$SBGC>}BGU!g zFw$%g!zQk!!QC7Tc8#FC^w-?st>vKDHkY_MO{E`%4^eF;Q$iF&F}gb(8^$W(!3FD( zKjR7!=~82wUiz4@WCaSBZ@@*L*5l3seF*;^jq;mLlU|`@Z2#g1(K40(+fsG(CEOD0}$hIK?$EDQO0LSj#XTBi>O|jR)`jb~KNt#hDJ16iD5=FKCvO z<KL z%_Ql?Qm}3JM~66hyeO#%)*j>G;M6%#F?SYhZj8ZxKQYKokbrLz=@9D_gN;4|WVme= zbe7B$KujLXLlhPO%Ior7F`!)|Chr$YM~Z~omx_Kc=wkTGN|thwWgy(+=rT~|xt z9K-BJ>0<1VQVcoJMs&m+@O9io7UKwFc+&vN`!I=IxX5tcg)2}(aT_#0eb)3t!Vji1 zyvKxra%i8p0Bmiq6aKNrrg27Lc7adurfe{CmqxgYT z3>aYz_nZy!vb`+i7{}wjd->p}H3PRks{r@bT=*Fs10xouV?neh^k^y4YZ0Yre7KXE zyjcv9iJM@vxHL?EpoHfhvvn%r7L|*sz%cuC5Nfo;pq>mg(3uK*69C>S6hgrvCs28& z4wlt@^i4z!?+*7K+ ze9CfgU|8Uzw+dh&DG3eIgJ9T=GT6R#IC`jLq5M+`c=gE%BP-Z` zYzhwjwgvjl5~;;fQ>^`Ag=&k7aIdcldi2Nu-Sml!Z(-Q&r{y4dsRGh|GL4KaMezDr z5RCE*!L=Wm{#CpqjOdJ_zOOQ%sO)wp}Czy1{Ip8?zrYHglMa z&C|zVp=l5}em1eL{f&0#Jf}5kR`Aj}9Cy5o#)g1I;&t5&Z6k(}zNJ9V?Hmk7bChUk zydpdfV!X0^6_~ke9~A9voB;ozU*Y&7#BxZZ7)U{ zCH9`NundFsw}F9k8O~eDhccODGSIOfy9ZAL`I(PNXGJ#+d6i3?{C%PMLTgiSbu^~Q zU7=!CIrw?BEz|7Vh6cQJRP7!~<#Ws#PjnHaO$$ZMb*Z%Mek0AZF($s*Iq+g(E`B!4 z2D`8iWDIMEv*(N97yn<**Nz-LX^(*u-jc)#;nZ3GxG$4M*L#e66T~eglPtC+a;r_C4FiCntOm`=sY&f9=$NcQi_b`6(8L&M(|t-Vua|?j>!YB7lf{Ff%Fs7q9>dBUq^6I)Hr*@T zk7J(I02%X)9B__?nNwC^@NNUBJXOx{>O+_|r3+pd6al7&;{`qegW*MC9jfQL!(_uS zXwA+>SG@`vIan5_j@u3AxqT3Bn@aRH>f)7u}^0xR;dbs`S)^MHRK>F zZZi~EF(2cpsv!6*8-1-d;xFw$?Oh|{zV;gB_J9=I^s8`GK>0x<}sJtyKw?81qt z6=?!Q?IzW%cuG^W;&8r2JoHW&3od56piF)%t-a%oWlHnVboYFw<#UvKnE}ljJ1IZpFdc4}0mY^f*l&EE?6+Qk7t%9f#LZz~v(y#$e5n93@gr5A zW)An79#F+|%P5VTgLMxpuyIo_IE zN2oEK)D$f8F2aoG<7nedYnn193m^2RqVzo_$a|cIgJzU~*XU}<>-tF=Y=Uw6(n&Pviuv&PR@qHOJCD;--x%jYUlWZq zYEc)|z=$wBmd;H<_=5q`!tggnLA!|hatBytcK}LxmU!ew47!$y(~e3nlxiPD0v+q9 zgq8!;vOTNm`7Car>=vzfrU_Eo$SuvBf+A;5k~pR%v_&%##7QK0_pG71v-UHr$40y! z;0^pmr<%&#c2Me6NA56x4arO;XIw54(X|Y-Y$7huc78|rImv>QYi#dbBS+o}i{ag? z^_b}Po*Q6eD<>pQ{mOqI@rrN#go>qM~m$V z=#UQ_FMgA8b5CvAfkn zOw|nqZL+-rzUD;ZXLBN;Cq{vUSJ&c`C(@|$LK<32XF!ZX4*tRwwsj@qjl>#Q(l{Q) zW;4F-+iFgAYzV38Uk2^Ei=b7h5;1BH-V`9Qf0F^cBRN6J%v=?pbky1WO8r%bKvvbXebht;cHJJ7&4_2o^2nDf#r$R zGI>1MJSjkr8O6k^v=UvjjIrTPE-gw-q3(Xfa*y-`tdU#cjpA-iM8b*zCR5+mzxnO$+n-qs5VK_&ig0mtEE{GcYG#paA%qfmE9bFNfLRYevD=+SThc&1;g51<1F6qfctCG zkWP&wn)Rb$^{8CNA$&x#!4|<(Vt6tJ zukDn?-;OhFRNra1l+(iT^7U9)!gNL)RE=GZOa$BgA^7ga9Ej}E0&T11r1`!wSrHip ziu=W3wcJeL3Ta#3FC>)nok{V<2(R`6b($gl~&rf6*gS7wSMd5^PE3`Bbp;PNJ{6 zDxkMH4);W6fh9YOBF#vO@$S~6b^Tdt9o#FJ=(ie7?XESQZ?1wm*;uMQpiAApB>}JH z{Y+jn!^q|YbG%da+-(0^Ts?dp8ee8QNLqE+<{dzLZzzBy(|8L^c+B?gt(2c-!C5F~ z;36+yT(`0mZHLYx59;T_>3ta(D-wp*pGHx;t3r_BZ^Ag{iS*9H1l*vJ0vGK^K;9fL zGH&u_2(VlYksku#vbZ@>&_>l%)a2Zci7 zy&crZaV!0KstDGb=mKv?0=2le7@t;FgI8HN^f->h3==ofJJJioA9d2-5)|;}D7Fu% zc}6t1-J%CJn4!LIF71>^!TjVj=%4 z!?@#s-n%`B&Bt#fc*rma@A*VFG)+MC$v`DZ9#%?uIsX zO4Ngv!J}#3q6!R{z5(C#j>VN@W`kLM4CZ+oaN5Nc9B)lHm-KKA#6DwjO13}0`?Ux= z?x&%eV;O#!@RUAIF2x(K=Hts&Nf`F12zk$!b6rEXf#lwK*nCNI>i>uuREZh<9rmm?}OSK=HiyTbm$*0OEPjbY2XNT+Ww;&C*(2xnH7;3S;I$t zuhn=nJ)AsxPzZCJhGXcq#h_g9jc9DEXIf-2Tz)i*tzUUVRi-iwX;2R-UR@05il)Q- zIm!5?sU8$V$IudUDO{eW#&8l#V3$}9I5ld6MF#V$8oE-E4^xT3iY)xGg@eYrR$>r8 z1}YtxW~Xo%DDE>7xO}&QgWsludBrNw8dwU^LlkIybUBgQoX)9QGwq1lamaI-d8{dw z;YD|*fnHZUX!pP5YFknvftL>k%p z1`_a`(?|$C7XUoVn~g2I?653Vi>|M!!7DTY)-wO`eec~+7T!)WZX_{2_8CE)kr+)h z8IH?$RD+-AMcRKbf!zHq6IKTDK-xYP6`yZJ)08g)9f?>Nu&F`Q1O-mygBNv~G>>VP z9uhcgd@sl^albAfT2D`d9_14HF|!sQEmg*7c3#03sd5VVC@~Rblu?fSW?haskAO?G_Y@pBR ztYF@-g}BI=>6?Z%k-UaU)aSxtka%kkllHb#ho>O|#f~n4_x3Clc3K2hYF}u3v?TC@ zj~gFdlL#U=#sV@7;^!-o;83@g#5c6k5c5LX{K}kmTr#J^Bg*l|tOaP&^O8PFsfE)E z^wDZUCaPRr4a1(!M#})kZDGD6lnTYXj}OSVwgZea5sAIVnwT>^1}3@gg!+yHpvcbg zxjlFbE_`5*`)1_AaYKJB(x&vB*=E=rxeEG^9w!cyzX(>FMxzlYL1V&GKyT-6XsS?# z<%#8Z^k5o_{8)m4Lmrc=pl|ezbUxF8ilLo}tGT*|L8O0mI^BOh365W6y6pBcux9c# zy7#LBezu>8jR8!{>#aIF3quC(#v5V4wLq{<*+oo)oyfTiWr*q0!>*2ra9t||{BAXq z111x29_9gOpoJ@M&qg(m)mW)pip_O{1)kzdQ2u&4X+E)1aAS8WOq4IfIT4}|$@CQT z##e)t)gEeae<7}z5DG!>s=-jN8mriNcBs5cc=xU|EioHXuyzF02{A18f`jCyvOOMV z8Z%SMia=3k30;0H0z8s@(Y`(or1wd}y`jfQ?ifpuSicm{bPeS+wT%TV_C0^E7h2$U+A zPK5s!;2ZJ?zhx`ol|>2MR}_Q$#B5NQe*o<-FTopcB(YP?ol|}Kp46%t;*z8Upf!2K z)J`0OoO9vHt#n$_y$CFi4kkiS4(o&>;D(R`u0Lad=0(+*%k-&CT^-5T1<5$$S}4c~ z&QiMv2B@mChOJ+v*we-E!_<)mmBoSi>_i;>b~1c8B!VW^vq7?79;;U|ERkNN;0V)c z+<9RNT)1Y4&02m$(b-$D_4n-Dd533_P4(@~N>G!G8`e>Tz$@GJZPg0|>8eHlxhTiLH1VK#Gc&cU? ztUv1pYpdp9n_DQCm>WdYw@m}W_=Eh9@#MBrHtfAU3bbEIkZ3N*EO$$i~6k7?(hXL=u^+$lGEXg=)A zorx$lHu%i#w}3{g8rORJl&NG z->pWWl)oS4-TZ1)(Ps(XzM=H_C*UT4me683SNG>N~B~*!snK7eDGliT-~}1 zn>Rlo8}8PCNHxVJ6KfX(RBR~`c*mY3 zv(u9xa^wQ=lYB<1N7X{JPcYS*Y6$9Q7cfpC<4B&b5gdPAg53u)P-AHxPPiNeo9`sz zsXSNMryz+7tgNv&RTi@%6JXK1G*~-A1f4>RVECwT;A!tLp3At=(^|{H_IoRl2z7%N z#o<)bWFkrj48n|;z9ceaJIwr)0woo?VD7B}1AZeR&3y;#l{CPjhHA7sp^mE0d35Qz zDdzRTRiQO(VfwUZCxAhwHf?jjq1&Ov7m=e%r}n8K1o2 zMtu^LMh?QZ+Y#h|T@}3lTF=geo5{2dZjcNSNvLZbN>-}RMgC7CVwK#)v=b7!i~H6? z=s2cd{y-0n>f`7s4OM6c0aq64PV4g&uz{T!H)*99%;u`VXMH^?x$VZi8k0f!@oqBF zxfoaMX8QC^a>O(}jaql*vKpob>W zehqCpy*U%C>cl|GKZ;xaaS@udCgWxKWW3`j30Gp1&|#V-SL7fCtpjFY%3_y`*X3cL zVmNhW+QXk_PJ{;(R)yJbjuG9MJWR8=PcJX7B2QNAWhl^@xV_s3HQwieeyk4DJv~kRucv|4>0qq? zS&pUE1z2?aG|5~r6fQfpla0Z*h)~X2Ed4M9r^aT$kHm=>R8a)e`RSl9kN_VwIXLk! z8uC=a1atO}1^$qAsC=)B-nqobQl=@ibCW(C+GvV{l&Zk_#XhW=wi*1Dv%vfCVX{d_ z0y~cxQHQ10M6F{qMhiSh%^P9dzr!5YIWj-kS8+NojA_a7Okib^3~nC21`=)Wb7y1g zQRI|9mi-(-uU#`mVi--Yk6Q&vA*-NWPZmYBmqLiTC~R?(L*CR(QvWyxVwbr>^Nk{m zI24Rt*P?*)SI1@UCQy+u4zvxQ6WgcuwAvyZ^b=~BmZ3aM^OV9n+M;ll@tOjU1fe_Q z(t0Q*U^{!a-dw5-hl^{$RrxGk5O2sh0&&E7m?Qp3s70goZq(G$h08FL#V7rY%gM_U ztl(E-|1&Mp5o=90m{fwpk_^(#_@uK3Z$x1S2l$cZ4qti|P;soZ!0tvh20pW;o=5ee z{Qg=r^}k_UcPm~Hxz7nYADo~or!9xgvOUy6Jy%fl=nPd{`H)nynA81-muYk>uHgcwM(Z&y7WJ05D5;>gA$21jp@RWDN0DpG= zOlJ)8Hxy$J9ZMrtg7*!WSoioP{P$kF12vNps=wbv`5pnqo?H z9LbSjT8Ng_f~!}U9+#3HTwhaz4%S1+xqMG>sja|q^X%cO=6*=tTnwL+Gr&xi`72!e z1=-a#*u8olocE7qc$mH9-jP_k{=hD0eCz6&m^O0^SqK{Lvz-DkYwmfjfNcRa4U@niRY6Gxtb2#^a&c=n- zwkS2_D93Y-GL}+1NAj*2k-<+h$aT3XkXuoO{O=nDa`UItqGLY^@6I6TAMudv;xj(d z!61%*i*bd<9u)jsnE-p-DnR<}b}*e1PIh-L!o<2q+%XmAPj2%h>n5W>1Q-@YE4J#Idw-_UEc$rK(JipQ(^KFsq zV7dh9D)i)yDi~Rqg^Lvw@%VEsc4k;Ut~HK9?XKCh?V~S|X%SG7#Pej7awWXhWB$Hp zD{1tn0+e2E1(%zIVD*;0uraxX?&RE{*V_PWG(%DQ>oC}|CKvc=V%*9bb*QVd4Ylh^ z@UV3hP8lwYyrQLo=7APcGoQULTp7>W{h5)Lk!?r`~gIB9IyL?#58p|@Eo z89zJ>Tr;vr+b3~i+oMBQv$G;IGTxAehsE?#uO0?&3?Q6wJ$o;zCt)dN@J?wwG?%H< zSr<7}E=nQG8g|0^?A0*;mIRjh1=H4HJD_1i1SwpbiBuwnl)G+5zTK#%#K3$Uzq16K zhh~D0=R}ZNT0$4MG5w7#PB3EmSX@?-jcbRcpk@48Vk~P3pS$N_U!x|y6kmz5Pu5|x z|G}m~Kf+_gy`NA1gsO!2p&Ymp~6C4+S^>da`FbOtID2BjTS*j%z2o6%$ zxDFQq85zj*(N~TEbINqcYJ-Vry)W*Q-46EZ6qK90Wj zj^1vfBSeNti`d(5aP;u7w-;3x5s_B;{j;OzMtdJ;KYt%j`(L*rP${YUfVpK zH~ynnelPIeUj-iZ_X5lOX9br1&kH>I?*$(7w*t%kcMGh}6BQF3!U`&+{d?*DmJnul z>x4zei;R^X=Y?PSCjXiEl~4K4#IJl){!IMJ_p6(hu)(1sE&7+iD=9Kn zTFl?|rHq9!e6@9({lCG36uoW1P>Jsf|p z#2C?0BI?qTzRo@YzrOvC0m^>qZx-^m#IJHB|1uYi|H@o6|6(rF{%S5-|IS>d|FgOL z9}oNyBLCBot|Bt>ziknctjo0j>@scEOusJZ{Plv9w4JELO0#A2&CIRl+OIUTS}|8z znpu#iuZx$@7H6O7f6OW44E`7}f1m3AadWzsnT^e|`Q|HZ<}RQ9_h;&;R>= z^^@*@)lYhV=_mca_S1}i*G~rj+)w|<^W6xM|KS8Dsv;uy-|DA7I>qp>o${-fjQ)Dz zSL;c$=~9SIm+X(Qw2;p4E&J=*%s;Lf|9S1tmi~8te+{9je@?&4r5F8^84UlAF+5KD zKi0zOe;L{TuHV;4+x*i*#s4-BjMMwega3Bt-#z=^)$ZTjH)4w+_t(WxT!j2}VH++2 z(nI{b{2V<*hl`j<582@5@9Fn@n_IG+v;Lmb^zS*%{#Q=(f8=Dd literal 0 HcmV?d00001 diff --git a/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index b/test/TensorFlowNET.Keras.UnitTest/Assets/python_func_model/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..e0b0e800ae504e35e032121aef03214ebfddd899 GIT binary patch literal 575 zcmZQzVB=tvV&Y(AkP(P?_HcFf4)FK%3vqPvagFzP@^W10F7ZWT~;k PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + diff --git a/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs index 0872394b..9230bc73 100644 --- a/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs @@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest ASSERT_EQ(TF_OK, s_.Code, s_.Message); ASSERT_NE(func_, IntPtr.Zero); ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); - c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); + c_api.TF_GraphCopyFunction(host_graph_, func_, new SafeFuncGraphHandle(IntPtr.Zero), s_); ASSERT_EQ(TF_OK, s_.Code, s_.Message); }